diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7e48edc1fcbd340725978ff41afa3ebd192d1db5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: Video LLaMA
+emoji: 🚀
+colorFrom: purple
+colorTo: gray
+sdk: gradio
+sdk_version: 3.29.0
+app_file: app.py
+pinned: false
+license: other
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f06d76f82f514275bcf52ac69be076709b4845c5
--- /dev/null
+++ b/app.py
@@ -0,0 +1,192 @@
+"""
+Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py
+"""
+import argparse
+import os
+import random
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import gradio as gr
+
+from video_llama.common.config import Config
+from video_llama.common.dist_utils import get_rank
+from video_llama.common.registry import registry
+from video_llama.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle
+import decord
+decord.bridge.set_bridge('torch')
+
+#%%
+# imports modules for registration
+from video_llama.datasets.builders import *
+from video_llama.models import *
+from video_llama.processors import *
+from video_llama.runners import *
+from video_llama.tasks import *
+
+#%%
+def parse_args():
+ parser = argparse.ArgumentParser(description="Demo")
+ parser.add_argument("--cfg-path", default='eval_configs/video_llama_eval.yaml', help="path to configuration file.")
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+ parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def setup_seeds(config):
+ seed = config.run_cfg.seed + get_rank()
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+
+# ========================================
+# Model Initialization
+# ========================================
+
+print('Initializing Chat')
+args = parse_args()
+cfg = Config(args)
+
+model_config = cfg.model_cfg
+model_config.device_8bit = args.gpu_id
+model_cls = registry.get_model_class(model_config.arch)
+model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+
+vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
+vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+print('Initialization Finished')
+
+# ========================================
+# Gradio Setting
+# ========================================
+
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+
+def upload_imgorvideo(gr_video, gr_img, text_input, chat_state):
+ if gr_img is None and gr_video is None:
+ return None, None, None, gr.update(interactive=True), chat_state, None
+ elif gr_img is not None and gr_video is None:
+ print(gr_img)
+ chat_state = Conversation(
+ system= "You are able to understand the visual content that the user provides."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ )
+ img_list = []
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+ elif gr_video is not None and gr_img is None:
+ print(gr_video)
+ chat_state = default_conversation.copy()
+ chat_state = Conversation(
+ system= "You are able to understand the visual content that the user provides."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ )
+ img_list = []
+ llm_message = chat.upload_video(gr_video, chat_state, img_list)
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+ else:
+ # img_list = []
+ return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None
+
+def gradio_ask(user_message, chatbot, chat_state):
+ if len(user_message) == 0:
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
+ chat.ask(user_message, chat_state)
+ chatbot = chatbot + [[user_message, None]]
+ return '', chatbot, chat_state
+
+
+def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+ llm_message = chat.answer(conv=chat_state,
+ img_list=img_list,
+ num_beams=num_beams,
+ temperature=temperature,
+ max_new_tokens=300,
+ max_length=2000)[0]
+ chatbot[-1][1] = llm_message
+ print(chat_state.get_prompt())
+ print(chat_state)
+ return chatbot, chat_state, img_list
+
+title = """
Demo of Video-LLaMA
"""
+description = """This is the demo of Video-LLaMA. Upload your images/videos and start chatting!
"""
+
+
+#TODO show examples below
+
+with gr.Blocks() as demo:
+ gr.Markdown(title)
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ video = gr.Video()
+ image = gr.Image(type="pil")
+
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
+ clear = gr.Button("Restart")
+
+ num_beams = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=1,
+ step=1,
+ interactive=True,
+ label="beam search numbers)",
+ )
+
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=2.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ with gr.Column():
+ chat_state = gr.State()
+ img_list = gr.State()
+ chatbot = gr.Chatbot(label='Video-LLaMA')
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image/video first', interactive=False)
+
+
+ upload_button.click(upload_imgorvideo, [video, image, text_input, chat_state], [video, image, text_input, upload_button, chat_state, img_list])
+
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+ )
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, image, text_input, upload_button, chat_state, img_list], queue=False)
+
+demo.launch(share=False, enable_queue=False)
+
+# %%
diff --git a/ckpt/blip2_pretrained_flant5xxl.pth b/ckpt/blip2_pretrained_flant5xxl.pth
new file mode 100644
index 0000000000000000000000000000000000000000..70fa1bcf8e62b78c270d100025736323df52e52c
--- /dev/null
+++ b/ckpt/blip2_pretrained_flant5xxl.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b3839ea6c617f315ead9bf4036bbb0f0cf6bf62695ecfc14968ea626af03a29
+size 433481467
diff --git a/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth b/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9e48012e9ecebc4811525b863b539295d358971a
--- /dev/null
+++ b/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc4b32437c90df51bc3faa29deaa9b25ab77e1707ac79066f17ae3193ebe8bfc
+size 1527692539
diff --git a/ckpt/finetune-vicuna7b-v2.pth b/ckpt/finetune-vicuna7b-v2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..ffcf94d3949e0d60d0c6d79392fbded8ea1aabe0
--- /dev/null
+++ b/ckpt/finetune-vicuna7b-v2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0680ad8eb14c2a3273b7be71309ab6b06c9f426e87ad4675a903371fe0fa8162
+size 265436777
diff --git a/ckpt/pretrain-billa7b-zh.pth b/ckpt/pretrain-billa7b-zh.pth
new file mode 100644
index 0000000000000000000000000000000000000000..e70ac5f43911ddecf2052cc1be40e5cce33f1c2f
--- /dev/null
+++ b/ckpt/pretrain-billa7b-zh.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f50a51db3055e1be6461f6dec833fbbbba28650287d26c8787664c8ee31dcf0f
+size 265435689
diff --git a/eval_configs/video_llama_eval.yaml b/eval_configs/video_llama_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..32be978a877ecdf08594f843871c769341b16903
--- /dev/null
+++ b/eval_configs/video_llama_eval.yaml
@@ -0,0 +1,32 @@
+model:
+ arch: video_llama
+ model_type: pretrain_vicuna
+ freeze_vit: True
+ freeze_qformer: True
+ max_txt_len: 512
+ end_sym: "###"
+ low_resource: False
+
+
+ llama_model: "DAMO-NLP-SG/vicuna-7b"
+
+ fusion_head_layers: 2
+ max_frame_pos: 32
+ fusion_header_type: "seqTransf"
+
+ ckpt: 'ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth'
+ q_former_model: 'ckpt/blip2_pretrained_flant5xxl.pth'
+
+datasets:
+ webvid:
+ vis_processor:
+ train:
+ name: "alpro_video_eval"
+ n_frms: 8
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+
+run:
+ task: video_text_pretrain
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a2d2ac49c79d36d9c4ca9262763499cf4bba2c81
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+transformers==4.28.0
+tqdm
+decord
+timm
+einops
+opencv_python
+torchvision
+
+salesforce-lavis
+bitsandbytes
+accelerate
\ No newline at end of file
diff --git a/video_llama/__init__.py b/video_llama/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b38d4c6f7611646be527265cd0bd868d4f1f81
--- /dev/null
+++ b/video_llama/__init__.py
@@ -0,0 +1,31 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import sys
+
+from omegaconf import OmegaConf
+
+from video_llama.common.registry import registry
+
+from video_llama.datasets.builders import *
+from video_llama.models import *
+from video_llama.processors import *
+from video_llama.tasks import *
+
+
+root_dir = os.path.dirname(os.path.abspath(__file__))
+default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
+
+registry.register_path("library_root", root_dir)
+repo_root = os.path.join(root_dir, "..")
+registry.register_path("repo_root", repo_root)
+cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+registry.register_path("cache_root", cache_root)
+
+registry.register("MAX_INT", sys.maxsize)
+registry.register("SPLIT_NAMES", ["train", "val", "test"])
diff --git a/video_llama/app.py b/video_llama/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f06d76f82f514275bcf52ac69be076709b4845c5
--- /dev/null
+++ b/video_llama/app.py
@@ -0,0 +1,192 @@
+"""
+Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py
+"""
+import argparse
+import os
+import random
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import gradio as gr
+
+from video_llama.common.config import Config
+from video_llama.common.dist_utils import get_rank
+from video_llama.common.registry import registry
+from video_llama.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle
+import decord
+decord.bridge.set_bridge('torch')
+
+#%%
+# imports modules for registration
+from video_llama.datasets.builders import *
+from video_llama.models import *
+from video_llama.processors import *
+from video_llama.runners import *
+from video_llama.tasks import *
+
+#%%
+def parse_args():
+ parser = argparse.ArgumentParser(description="Demo")
+ parser.add_argument("--cfg-path", default='eval_configs/video_llama_eval.yaml', help="path to configuration file.")
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+ parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def setup_seeds(config):
+ seed = config.run_cfg.seed + get_rank()
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+
+# ========================================
+# Model Initialization
+# ========================================
+
+print('Initializing Chat')
+args = parse_args()
+cfg = Config(args)
+
+model_config = cfg.model_cfg
+model_config.device_8bit = args.gpu_id
+model_cls = registry.get_model_class(model_config.arch)
+model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+
+vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
+vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+print('Initialization Finished')
+
+# ========================================
+# Gradio Setting
+# ========================================
+
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+
+def upload_imgorvideo(gr_video, gr_img, text_input, chat_state):
+ if gr_img is None and gr_video is None:
+ return None, None, None, gr.update(interactive=True), chat_state, None
+ elif gr_img is not None and gr_video is None:
+ print(gr_img)
+ chat_state = Conversation(
+ system= "You are able to understand the visual content that the user provides."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ )
+ img_list = []
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+ elif gr_video is not None and gr_img is None:
+ print(gr_video)
+ chat_state = default_conversation.copy()
+ chat_state = Conversation(
+ system= "You are able to understand the visual content that the user provides."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ )
+ img_list = []
+ llm_message = chat.upload_video(gr_video, chat_state, img_list)
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+ else:
+ # img_list = []
+ return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None
+
+def gradio_ask(user_message, chatbot, chat_state):
+ if len(user_message) == 0:
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
+ chat.ask(user_message, chat_state)
+ chatbot = chatbot + [[user_message, None]]
+ return '', chatbot, chat_state
+
+
+def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+ llm_message = chat.answer(conv=chat_state,
+ img_list=img_list,
+ num_beams=num_beams,
+ temperature=temperature,
+ max_new_tokens=300,
+ max_length=2000)[0]
+ chatbot[-1][1] = llm_message
+ print(chat_state.get_prompt())
+ print(chat_state)
+ return chatbot, chat_state, img_list
+
+title = """Demo of Video-LLaMA
"""
+description = """This is the demo of Video-LLaMA. Upload your images/videos and start chatting!
"""
+
+
+#TODO show examples below
+
+with gr.Blocks() as demo:
+ gr.Markdown(title)
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ video = gr.Video()
+ image = gr.Image(type="pil")
+
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
+ clear = gr.Button("Restart")
+
+ num_beams = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=1,
+ step=1,
+ interactive=True,
+ label="beam search numbers)",
+ )
+
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=2.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ with gr.Column():
+ chat_state = gr.State()
+ img_list = gr.State()
+ chatbot = gr.Chatbot(label='Video-LLaMA')
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image/video first', interactive=False)
+
+
+ upload_button.click(upload_imgorvideo, [video, image, text_input, chat_state], [video, image, text_input, upload_button, chat_state, img_list])
+
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+ )
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, image, text_input, upload_button, chat_state, img_list], queue=False)
+
+demo.launch(share=False, enable_queue=False)
+
+# %%
diff --git a/video_llama/ckpt/blip2_pretrained_flant5xxl.pth b/video_llama/ckpt/blip2_pretrained_flant5xxl.pth
new file mode 100644
index 0000000000000000000000000000000000000000..70fa1bcf8e62b78c270d100025736323df52e52c
--- /dev/null
+++ b/video_llama/ckpt/blip2_pretrained_flant5xxl.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b3839ea6c617f315ead9bf4036bbb0f0cf6bf62695ecfc14968ea626af03a29
+size 433481467
diff --git a/video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth b/video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9c45fc8ea3c715c39e1d18de5ac6c63b22e9b224
--- /dev/null
+++ b/video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46af76d307c14d28c56534e4bf8654343e5512aa1285fc1c1fdb5728c418e7ca
+size 623104000
diff --git a/video_llama/ckpt/pretrain-billa7b-zh.pth b/video_llama/ckpt/pretrain-billa7b-zh.pth
new file mode 100644
index 0000000000000000000000000000000000000000..e70ac5f43911ddecf2052cc1be40e5cce33f1c2f
--- /dev/null
+++ b/video_llama/ckpt/pretrain-billa7b-zh.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f50a51db3055e1be6461f6dec833fbbbba28650287d26c8787664c8ee31dcf0f
+size 265435689
diff --git a/video_llama/common/__init__.py b/video_llama/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/video_llama/common/config.py b/video_llama/common/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a718644414610b2f10a01a51ebe58c14ff5f30
--- /dev/null
+++ b/video_llama/common/config.py
@@ -0,0 +1,468 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import json
+from typing import Dict
+
+from omegaconf import OmegaConf
+from video_llama.common.registry import registry
+
+
+class Config:
+ def __init__(self, args):
+ self.config = {}
+
+ self.args = args
+
+ # Register the config and configuration for setup
+ registry.register("configuration", self)
+
+ user_config = self._build_opt_list(self.args.options)
+
+ config = OmegaConf.load(self.args.cfg_path)
+
+ runner_config = self.build_runner_config(config)
+ model_config = self.build_model_config(config, **user_config)
+ dataset_config = self.build_dataset_config(config)
+
+ # Validate the user-provided runner configuration
+ # model and dataset configuration are supposed to be validated by the respective classes
+ # [TODO] validate the model/dataset configuration
+ # self._validate_runner_config(runner_config)
+
+ # Override the default configuration with user options.
+ self.config = OmegaConf.merge(
+ runner_config, model_config, dataset_config, user_config
+ )
+
+ def _validate_runner_config(self, runner_config):
+ """
+ This method validates the configuration, such that
+ 1) all the user specified options are valid;
+ 2) no type mismatches between the user specified options and the config.
+ """
+ runner_config_validator = create_runner_config_validator()
+ runner_config_validator.validate(runner_config)
+
+ def _build_opt_list(self, opts):
+ opts_dot_list = self._convert_to_dot_list(opts)
+ return OmegaConf.from_dotlist(opts_dot_list)
+
+ @staticmethod
+ def build_model_config(config, **kwargs):
+ model = config.get("model", None)
+ assert model is not None, "Missing model configuration file."
+
+ model_cls = registry.get_model_class(model.arch)
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
+
+ model_type = kwargs.get("model.model_type", None)
+ if not model_type:
+ model_type = model.get("model_type", None)
+ # else use the model type selected by user.
+
+ assert model_type is not None, "Missing model_type."
+
+ model_config_path = model_cls.default_config_path(model_type=model_type)
+
+ model_config = OmegaConf.create()
+ # hierarchy override, customized config > default config
+ model_config = OmegaConf.merge(
+ model_config,
+ OmegaConf.load(model_config_path),
+ {"model": config["model"]},
+ )
+
+ return model_config
+
+ @staticmethod
+ def build_runner_config(config):
+ return {"run": config.run}
+
+ @staticmethod
+ def build_dataset_config(config):
+ datasets = config.get("datasets", None)
+ if datasets is None:
+ raise KeyError(
+ "Expecting 'datasets' as the root key for dataset configuration."
+ )
+
+ dataset_config = OmegaConf.create()
+
+ for dataset_name in datasets:
+ builder_cls = registry.get_builder_class(dataset_name)
+
+ dataset_config_type = datasets[dataset_name].get("type", "default")
+ dataset_config_path = builder_cls.default_config_path(
+ type=dataset_config_type
+ )
+
+ # hierarchy override, customized config > default config
+ dataset_config = OmegaConf.merge(
+ dataset_config,
+ OmegaConf.load(dataset_config_path),
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
+ )
+
+ return dataset_config
+
+ def _convert_to_dot_list(self, opts):
+ if opts is None:
+ opts = []
+
+ if len(opts) == 0:
+ return opts
+
+ has_equal = opts[0].find("=") != -1
+
+ if has_equal:
+ return opts
+
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
+
+ def get_config(self):
+ return self.config
+
+ @property
+ def run_cfg(self):
+ return self.config.run
+
+ @property
+ def datasets_cfg(self):
+ return self.config.datasets
+
+ @property
+ def model_cfg(self):
+ return self.config.model
+
+ def pretty_print(self):
+ logging.info("\n===== Running Parameters =====")
+ logging.info(self._convert_node_to_json(self.config.run))
+
+ logging.info("\n====== Dataset Attributes ======")
+ datasets = self.config.datasets
+
+ for dataset in datasets:
+ if dataset in self.config.datasets:
+ logging.info(f"\n======== {dataset} =======")
+ dataset_config = self.config.datasets[dataset]
+ logging.info(self._convert_node_to_json(dataset_config))
+ else:
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
+
+ logging.info(f"\n====== Model Attributes ======")
+ logging.info(self._convert_node_to_json(self.config.model))
+
+ def _convert_node_to_json(self, node):
+ container = OmegaConf.to_container(node, resolve=True)
+ return json.dumps(container, indent=4, sort_keys=True)
+
+ def to_dict(self):
+ return OmegaConf.to_container(self.config)
+
+
+def node_to_dict(node):
+ return OmegaConf.to_container(node)
+
+
+class ConfigValidator:
+ """
+ This is a preliminary implementation to centralize and validate the configuration.
+ May be altered in the future.
+
+ A helper class to validate configurations from yaml file.
+
+ This serves the following purposes:
+ 1. Ensure all the options in the yaml are defined, raise error if not.
+ 2. when type mismatches are found, the validator will raise an error.
+ 3. a central place to store and display helpful messages for supported configurations.
+
+ """
+
+ class _Argument:
+ def __init__(self, name, choices=None, type=None, help=None):
+ self.name = name
+ self.val = None
+ self.choices = choices
+ self.type = type
+ self.help = help
+
+ def __str__(self):
+ s = f"{self.name}={self.val}"
+ if self.type is not None:
+ s += f", ({self.type})"
+ if self.choices is not None:
+ s += f", choices: {self.choices}"
+ if self.help is not None:
+ s += f", ({self.help})"
+ return s
+
+ def __init__(self, description):
+ self.description = description
+
+ self.arguments = dict()
+
+ self.parsed_args = None
+
+ def __getitem__(self, key):
+ assert self.parsed_args is not None, "No arguments parsed yet."
+
+ return self.parsed_args[key]
+
+ def __str__(self) -> str:
+ return self.format_help()
+
+ def add_argument(self, *args, **kwargs):
+ """
+ Assume the first argument is the name of the argument.
+ """
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
+
+ def validate(self, config=None):
+ """
+ Convert yaml config (dict-like) to list, required by argparse.
+ """
+ for k, v in config.items():
+ assert (
+ k in self.arguments
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
+
+ if self.arguments[k].type is not None:
+ try:
+ self.arguments[k].val = self.arguments[k].type(v)
+ except ValueError:
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
+
+ if self.arguments[k].choices is not None:
+ assert (
+ v in self.arguments[k].choices
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
+
+ return config
+
+ def format_arguments(self):
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
+
+ def format_help(self):
+ # description + key-value pair string for each argument
+ help_msg = str(self.description)
+ return help_msg + ", available arguments: " + self.format_arguments()
+
+ def print_help(self):
+ # display help message
+ print(self.format_help())
+
+
+def create_runner_config_validator():
+ validator = ConfigValidator(description="Runner configurations")
+
+ validator.add_argument(
+ "runner",
+ type=str,
+ choices=["runner_base", "runner_iter"],
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
+ runner runs based on iters. Default: runner_base""",
+ )
+ # add argumetns for training dataset ratios
+ validator.add_argument(
+ "train_dataset_ratios",
+ type=Dict[str, float],
+ help="""Ratios of training dataset. This is used in iteration-based runner.
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
+ Default: None""",
+ )
+ validator.add_argument(
+ "max_iters",
+ type=float,
+ help="Maximum number of iterations to run.",
+ )
+ validator.add_argument(
+ "max_epoch",
+ type=int,
+ help="Maximum number of epochs to run.",
+ )
+ # add arguments for iters_per_inner_epoch
+ validator.add_argument(
+ "iters_per_inner_epoch",
+ type=float,
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
+ )
+ lr_scheds_choices = registry.list_lr_schedulers()
+ validator.add_argument(
+ "lr_sched",
+ type=str,
+ choices=lr_scheds_choices,
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
+ )
+ task_choices = registry.list_tasks()
+ validator.add_argument(
+ "task",
+ type=str,
+ choices=task_choices,
+ help="Task to use, from {}".format(task_choices),
+ )
+ # add arguments for init_lr
+ validator.add_argument(
+ "init_lr",
+ type=float,
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
+ )
+ # add arguments for min_lr
+ validator.add_argument(
+ "min_lr",
+ type=float,
+ help="Minimum learning rate (after decay).",
+ )
+ # add arguments for warmup_lr
+ validator.add_argument(
+ "warmup_lr",
+ type=float,
+ help="Starting learning rate for warmup.",
+ )
+ # add arguments for learning rate decay rate
+ validator.add_argument(
+ "lr_decay_rate",
+ type=float,
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
+ )
+ # add arguments for weight decay
+ validator.add_argument(
+ "weight_decay",
+ type=float,
+ help="Weight decay rate.",
+ )
+ # add arguments for training batch size
+ validator.add_argument(
+ "batch_size_train",
+ type=int,
+ help="Training batch size.",
+ )
+ # add arguments for evaluation batch size
+ validator.add_argument(
+ "batch_size_eval",
+ type=int,
+ help="Evaluation batch size, including validation and testing.",
+ )
+ # add arguments for number of workers for data loading
+ validator.add_argument(
+ "num_workers",
+ help="Number of workers for data loading.",
+ )
+ # add arguments for warm up steps
+ validator.add_argument(
+ "warmup_steps",
+ type=int,
+ help="Number of warmup steps. Required if a warmup schedule is used.",
+ )
+ # add arguments for random seed
+ validator.add_argument(
+ "seed",
+ type=int,
+ help="Random seed.",
+ )
+ # add arguments for output directory
+ validator.add_argument(
+ "output_dir",
+ type=str,
+ help="Output directory to save checkpoints and logs.",
+ )
+ # add arguments for whether only use evaluation
+ validator.add_argument(
+ "evaluate",
+ help="Whether to only evaluate the model. If true, training will not be performed.",
+ )
+ # add arguments for splits used for training, e.g. ["train", "val"]
+ validator.add_argument(
+ "train_splits",
+ type=list,
+ help="Splits to use for training.",
+ )
+ # add arguments for splits used for validation, e.g. ["val"]
+ validator.add_argument(
+ "valid_splits",
+ type=list,
+ help="Splits to use for validation. If not provided, will skip the validation.",
+ )
+ # add arguments for splits used for testing, e.g. ["test"]
+ validator.add_argument(
+ "test_splits",
+ type=list,
+ help="Splits to use for testing. If not provided, will skip the testing.",
+ )
+ # add arguments for accumulating gradient for iterations
+ validator.add_argument(
+ "accum_grad_iters",
+ type=int,
+ help="Number of iterations to accumulate gradient for.",
+ )
+
+ # ====== distributed training ======
+ validator.add_argument(
+ "device",
+ type=str,
+ choices=["cpu", "cuda"],
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
+ )
+ validator.add_argument(
+ "world_size",
+ type=int,
+ help="Number of processes participating in the job.",
+ )
+ validator.add_argument("dist_url", type=str)
+ validator.add_argument("distributed", type=bool)
+ # add arguments to opt using distributed sampler during evaluation or not
+ validator.add_argument(
+ "use_dist_eval_sampler",
+ type=bool,
+ help="Whether to use distributed sampler during evaluation or not.",
+ )
+
+ # ====== task specific ======
+ # generation task specific arguments
+ # add arguments for maximal length of text output
+ validator.add_argument(
+ "max_len",
+ type=int,
+ help="Maximal length of text output.",
+ )
+ # add arguments for minimal length of text output
+ validator.add_argument(
+ "min_len",
+ type=int,
+ help="Minimal length of text output.",
+ )
+ # add arguments number of beams
+ validator.add_argument(
+ "num_beams",
+ type=int,
+ help="Number of beams used for beam search.",
+ )
+
+ # vqa task specific arguments
+ # add arguments for number of answer candidates
+ validator.add_argument(
+ "num_ans_candidates",
+ type=int,
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
+ )
+ # add arguments for inference method
+ validator.add_argument(
+ "inference_method",
+ type=str,
+ choices=["genearte", "rank"],
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
+ )
+
+ # ====== model specific ======
+ validator.add_argument(
+ "k_test",
+ type=int,
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
+ )
+
+ return validator
diff --git a/video_llama/common/dist_utils.py b/video_llama/common/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9280150bf5122d51bb810a9f0258a233e7088647
--- /dev/null
+++ b/video_llama/common/dist_utils.py
@@ -0,0 +1,137 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import timm.models.hub as timm_hub
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def init_distributed_mode(args):
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ elif "SLURM_PROCID" in os.environ:
+ args.rank = int(os.environ["SLURM_PROCID"])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print("Not using distributed mode")
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}, world {}): {}".format(
+ args.rank, args.world_size, args.dist_url
+ ),
+ flush=True,
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ timeout=datetime.timedelta(
+ days=365
+ ), # allow auto-downloading and de-compressing
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+def get_dist_info():
+ if torch.__version__ < "1.0":
+ initialized = dist._initialized
+ else:
+ initialized = dist.is_initialized()
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else: # non-distributed training
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def main_process(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+ """
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+ """
+
+ def get_cached_file_path():
+ # a hack to sync the file path across processes
+ parts = torch.hub.urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+ return cached_file
+
+ if is_main_process():
+ timm_hub.download_cached_file(url, check_hash, progress)
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ return get_cached_file_path()
diff --git a/video_llama/common/gradcam.py b/video_llama/common/gradcam.py
new file mode 100644
index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0
--- /dev/null
+++ b/video_llama/common/gradcam.py
@@ -0,0 +1,24 @@
+import numpy as np
+from matplotlib import pyplot as plt
+from scipy.ndimage import filters
+from skimage import transform as skimage_transform
+
+
+def getAttMap(img, attMap, blur=True, overlap=True):
+ attMap -= attMap.min()
+ if attMap.max() > 0:
+ attMap /= attMap.max()
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
+ if blur:
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
+ attMap -= attMap.min()
+ attMap /= attMap.max()
+ cmap = plt.get_cmap("jet")
+ attMapV = cmap(attMap)
+ attMapV = np.delete(attMapV, 3, 2)
+ if overlap:
+ attMap = (
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+ )
+ return attMap
diff --git a/video_llama/common/logger.py b/video_llama/common/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9de0e65de4601e9ea10f3af6372adabdd09e44c
--- /dev/null
+++ b/video_llama/common/logger.py
@@ -0,0 +1,195 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import time
+from collections import defaultdict, deque
+
+import torch
+import torch.distributed as dist
+
+from video_llama.common import dist_utils
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not dist_utils.is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ log_msg = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_msg.append("max mem: {memory:.0f}")
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len(iterable)
+ )
+ )
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def setup_logger():
+ logging.basicConfig(
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
+ format="%(asctime)s [%(levelname)s] %(message)s",
+ handlers=[logging.StreamHandler()],
+ )
diff --git a/video_llama/common/optims.py b/video_llama/common/optims.py
new file mode 100644
index 0000000000000000000000000000000000000000..b466e38dc4ceba80ea54759ba608b7281c583bed
--- /dev/null
+++ b/video_llama/common/optims.py
@@ -0,0 +1,119 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import math
+
+from video_llama.common.registry import registry
+
+
+@registry.register_lr_scheduler("linear_warmup_step_lr")
+class LinearWarmupStepLRScheduler:
+ def __init__(
+ self,
+ optimizer,
+ max_epoch,
+ min_lr,
+ init_lr,
+ decay_rate=1,
+ warmup_start_lr=-1,
+ warmup_steps=0,
+ **kwargs
+ ):
+ self.optimizer = optimizer
+
+ self.max_epoch = max_epoch
+ self.min_lr = min_lr
+
+ self.decay_rate = decay_rate
+
+ self.init_lr = init_lr
+ self.warmup_steps = warmup_steps
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+ def step(self, cur_epoch, cur_step):
+ if cur_epoch == 0:
+ warmup_lr_schedule(
+ step=cur_step,
+ optimizer=self.optimizer,
+ max_step=self.warmup_steps,
+ init_lr=self.warmup_start_lr,
+ max_lr=self.init_lr,
+ )
+ else:
+ step_lr_schedule(
+ epoch=cur_epoch,
+ optimizer=self.optimizer,
+ init_lr=self.init_lr,
+ min_lr=self.min_lr,
+ decay_rate=self.decay_rate,
+ )
+
+
+@registry.register_lr_scheduler("linear_warmup_cosine_lr")
+class LinearWarmupCosineLRScheduler:
+ def __init__(
+ self,
+ optimizer,
+ max_epoch,
+ iters_per_epoch,
+ min_lr,
+ init_lr,
+ warmup_steps=0,
+ warmup_start_lr=-1,
+ **kwargs
+ ):
+ self.optimizer = optimizer
+
+ self.max_epoch = max_epoch
+ self.iters_per_epoch = iters_per_epoch
+ self.min_lr = min_lr
+
+ self.init_lr = init_lr
+ self.warmup_steps = warmup_steps
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+ def step(self, cur_epoch, cur_step):
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
+ if total_cur_step < self.warmup_steps:
+ warmup_lr_schedule(
+ step=cur_step,
+ optimizer=self.optimizer,
+ max_step=self.warmup_steps,
+ init_lr=self.warmup_start_lr,
+ max_lr=self.init_lr,
+ )
+ else:
+ cosine_lr_schedule(
+ epoch=total_cur_step,
+ optimizer=self.optimizer,
+ max_epoch=self.max_epoch * self.iters_per_epoch,
+ init_lr=self.init_lr,
+ min_lr=self.min_lr,
+ )
+
+
+def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+ """Decay the learning rate"""
+ lr = (init_lr - min_lr) * 0.5 * (
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
+ ) + min_lr
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+
+def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+ """Warmup the learning rate"""
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+
+def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+ """Decay the learning rate"""
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
diff --git a/video_llama/common/registry.py b/video_llama/common/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0d0171ceb39123c7402a554eb8543ce55ff6881
--- /dev/null
+++ b/video_llama/common/registry.py
@@ -0,0 +1,329 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+
+class Registry:
+ mapping = {
+ "builder_name_mapping": {},
+ "task_name_mapping": {},
+ "processor_name_mapping": {},
+ "model_name_mapping": {},
+ "lr_scheduler_name_mapping": {},
+ "runner_name_mapping": {},
+ "state": {},
+ "paths": {},
+ }
+
+ @classmethod
+ def register_builder(cls, name):
+ r"""Register a dataset builder to registry with key 'name'
+
+ Args:
+ name: Key with which the builder will be registered.
+
+ Usage:
+
+ from video_llama.common.registry import registry
+ from video_llama.datasets.base_dataset_builder import BaseDatasetBuilder
+ """
+
+ def wrap(builder_cls):
+ from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+
+ assert issubclass(
+ builder_cls, BaseDatasetBuilder
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
+ builder_cls
+ )
+ if name in cls.mapping["builder_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["builder_name_mapping"][name]
+ )
+ )
+ cls.mapping["builder_name_mapping"][name] = builder_cls
+ return builder_cls
+
+ return wrap
+
+ @classmethod
+ def register_task(cls, name):
+ r"""Register a task to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from video_llama.common.registry import registry
+ """
+
+ def wrap(task_cls):
+ from video_llama.tasks.base_task import BaseTask
+
+ assert issubclass(
+ task_cls, BaseTask
+ ), "All tasks must inherit BaseTask class"
+ if name in cls.mapping["task_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["task_name_mapping"][name]
+ )
+ )
+ cls.mapping["task_name_mapping"][name] = task_cls
+ return task_cls
+
+ return wrap
+
+ @classmethod
+ def register_model(cls, name):
+ r"""Register a task to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from video_llama.common.registry import registry
+ """
+
+ def wrap(model_cls):
+ from video_llama.models import BaseModel
+
+ assert issubclass(
+ model_cls, BaseModel
+ ), "All models must inherit BaseModel class"
+ if name in cls.mapping["model_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["model_name_mapping"][name]
+ )
+ )
+ cls.mapping["model_name_mapping"][name] = model_cls
+ return model_cls
+
+ return wrap
+
+ @classmethod
+ def register_processor(cls, name):
+ r"""Register a processor to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from video_llama.common.registry import registry
+ """
+
+ def wrap(processor_cls):
+ from video_llama.processors import BaseProcessor
+
+ assert issubclass(
+ processor_cls, BaseProcessor
+ ), "All processors must inherit BaseProcessor class"
+ if name in cls.mapping["processor_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["processor_name_mapping"][name]
+ )
+ )
+ cls.mapping["processor_name_mapping"][name] = processor_cls
+ return processor_cls
+
+ return wrap
+
+ @classmethod
+ def register_lr_scheduler(cls, name):
+ r"""Register a model to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from video_llama.common.registry import registry
+ """
+
+ def wrap(lr_sched_cls):
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
+ )
+ )
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
+ return lr_sched_cls
+
+ return wrap
+
+ @classmethod
+ def register_runner(cls, name):
+ r"""Register a model to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from video_llama.common.registry import registry
+ """
+
+ def wrap(runner_cls):
+ if name in cls.mapping["runner_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["runner_name_mapping"][name]
+ )
+ )
+ cls.mapping["runner_name_mapping"][name] = runner_cls
+ return runner_cls
+
+ return wrap
+
+ @classmethod
+ def register_path(cls, name, path):
+ r"""Register a path to registry with key 'name'
+
+ Args:
+ name: Key with which the path will be registered.
+
+ Usage:
+
+ from video_llama.common.registry import registry
+ """
+ assert isinstance(path, str), "All path must be str."
+ if name in cls.mapping["paths"]:
+ raise KeyError("Name '{}' already registered.".format(name))
+ cls.mapping["paths"][name] = path
+
+ @classmethod
+ def register(cls, name, obj):
+ r"""Register an item to registry with key 'name'
+
+ Args:
+ name: Key with which the item will be registered.
+
+ Usage::
+
+ from video_llama.common.registry import registry
+
+ registry.register("config", {})
+ """
+ path = name.split(".")
+ current = cls.mapping["state"]
+
+ for part in path[:-1]:
+ if part not in current:
+ current[part] = {}
+ current = current[part]
+
+ current[path[-1]] = obj
+
+ # @classmethod
+ # def get_trainer_class(cls, name):
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_builder_class(cls, name):
+ return cls.mapping["builder_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_model_class(cls, name):
+ return cls.mapping["model_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_task_class(cls, name):
+ return cls.mapping["task_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_processor_class(cls, name):
+ return cls.mapping["processor_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_lr_scheduler_class(cls, name):
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_runner_class(cls, name):
+ return cls.mapping["runner_name_mapping"].get(name, None)
+
+ @classmethod
+ def list_runners(cls):
+ return sorted(cls.mapping["runner_name_mapping"].keys())
+
+ @classmethod
+ def list_models(cls):
+ return sorted(cls.mapping["model_name_mapping"].keys())
+
+ @classmethod
+ def list_tasks(cls):
+ return sorted(cls.mapping["task_name_mapping"].keys())
+
+ @classmethod
+ def list_processors(cls):
+ return sorted(cls.mapping["processor_name_mapping"].keys())
+
+ @classmethod
+ def list_lr_schedulers(cls):
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
+
+ @classmethod
+ def list_datasets(cls):
+ return sorted(cls.mapping["builder_name_mapping"].keys())
+
+ @classmethod
+ def get_path(cls, name):
+ return cls.mapping["paths"].get(name, None)
+
+ @classmethod
+ def get(cls, name, default=None, no_warning=False):
+ r"""Get an item from registry with key 'name'
+
+ Args:
+ name (string): Key whose value needs to be retrieved.
+ default: If passed and key is not in registry, default value will
+ be returned with a warning. Default: None
+ no_warning (bool): If passed as True, warning when key doesn't exist
+ will not be generated. Useful for MMF's
+ internal operations. Default: False
+ """
+ original_name = name
+ name = name.split(".")
+ value = cls.mapping["state"]
+ for subname in name:
+ value = value.get(subname, default)
+ if value is default:
+ break
+
+ if (
+ "writer" in cls.mapping["state"]
+ and value == default
+ and no_warning is False
+ ):
+ cls.mapping["state"]["writer"].warning(
+ "Key {} is not present in registry, returning default value "
+ "of {}".format(original_name, default)
+ )
+ return value
+
+ @classmethod
+ def unregister(cls, name):
+ r"""Remove an item from registry with key 'name'
+
+ Args:
+ name: Key which needs to be removed.
+ Usage::
+
+ from mmf.common.registry import registry
+
+ config = registry.unregister("config")
+ """
+ return cls.mapping["state"].pop(name, None)
+
+
+registry = Registry()
diff --git a/video_llama/common/utils.py b/video_llama/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1768fcdfd73b057877a7b0a7c1f10a3aa057caa
--- /dev/null
+++ b/video_llama/common/utils.py
@@ -0,0 +1,424 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import io
+import json
+import logging
+import os
+import pickle
+import re
+import shutil
+import urllib
+import urllib.error
+import urllib.request
+from typing import Optional
+from urllib.parse import urlparse
+
+import numpy as np
+import pandas as pd
+import yaml
+from iopath.common.download import download
+from iopath.common.file_io import file_lock, g_pathmgr
+from video_llama.common.registry import registry
+from torch.utils.model_zoo import tqdm
+from torchvision.datasets.utils import (
+ check_integrity,
+ download_file_from_google_drive,
+ extract_archive,
+)
+
+
+def now():
+ from datetime import datetime
+
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
+
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+
+def get_cache_path(rel_path):
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
+
+
+def get_abs_path(rel_path):
+ return os.path.join(registry.get_path("library_root"), rel_path)
+
+
+def load_json(filename):
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+# The following are adapted from torchvision and vissl
+# torchvision: https://github.com/pytorch/vision
+# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
+
+
+def makedir(dir_path):
+ """
+ Create the directory if it does not exist.
+ """
+ is_success = False
+ try:
+ if not g_pathmgr.exists(dir_path):
+ g_pathmgr.mkdirs(dir_path)
+ is_success = True
+ except BaseException:
+ print(f"Error creating directory: {dir_path}")
+ return is_success
+
+
+def get_redirected_url(url: str):
+ """
+ Given a URL, returns the URL it redirects to or the
+ original URL in case of no indirection
+ """
+ import requests
+
+ with requests.Session() as session:
+ with session.get(url, stream=True, allow_redirects=True) as response:
+ if response.history:
+ return response.url
+ else:
+ return url
+
+
+def to_google_drive_download_url(view_url: str) -> str:
+ """
+ Utility function to transform a view URL of google drive
+ to a download URL for google drive
+ Example input:
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
+ Example output:
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
+ """
+ splits = view_url.split("/")
+ assert splits[-1] == "view"
+ file_id = splits[-2]
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
+
+
+def download_google_drive_url(url: str, output_path: str, output_file_name: str):
+ """
+ Download a file from google drive
+ Downloading an URL from google drive requires confirmation when
+ the file of the size is too big (google drive notifies that
+ anti-viral checks cannot be performed on such files)
+ """
+ import requests
+
+ with requests.Session() as session:
+
+ # First get the confirmation token and append it to the URL
+ with session.get(url, stream=True, allow_redirects=True) as response:
+ for k, v in response.cookies.items():
+ if k.startswith("download_warning"):
+ url = url + "&confirm=" + v
+
+ # Then download the content of the file
+ with session.get(url, stream=True, verify=True) as response:
+ makedir(output_path)
+ path = os.path.join(output_path, output_file_name)
+ total_size = int(response.headers.get("Content-length", 0))
+ with open(path, "wb") as file:
+ from tqdm import tqdm
+
+ with tqdm(total=total_size) as progress_bar:
+ for block in response.iter_content(
+ chunk_size=io.DEFAULT_BUFFER_SIZE
+ ):
+ file.write(block)
+ progress_bar.update(len(block))
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+ parts = urlparse(url)
+
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+ return None
+
+ match = re.match(r"/file/d/(?P[^/]*)", parts.path)
+ if match is None:
+ return None
+
+ return match.group("id")
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
+ with open(filename, "wb") as fh:
+ with urllib.request.urlopen(
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
+ ) as response:
+ with tqdm(total=response.length) as pbar:
+ for chunk in iter(lambda: response.read(chunk_size), ""):
+ if not chunk:
+ break
+ pbar.update(chunk_size)
+ fh.write(chunk)
+
+
+def download_url(
+ url: str,
+ root: str,
+ filename: Optional[str] = None,
+ md5: Optional[str] = None,
+) -> None:
+ """Download a file from a url and place it in root.
+ Args:
+ url (str): URL to download file from
+ root (str): Directory to place downloaded file in
+ filename (str, optional): Name to save the file under.
+ If None, use the basename of the URL.
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
+ """
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = os.path.basename(url)
+ fpath = os.path.join(root, filename)
+
+ makedir(root)
+
+ # check if file is already present locally
+ if check_integrity(fpath, md5):
+ print("Using downloaded and verified file: " + fpath)
+ return
+
+ # expand redirect chain if needed
+ url = get_redirected_url(url)
+
+ # check if file is located on Google Drive
+ file_id = _get_google_drive_file_id(url)
+ if file_id is not None:
+ return download_file_from_google_drive(file_id, root, filename, md5)
+
+ # download the file
+ try:
+ print("Downloading " + url + " to " + fpath)
+ _urlretrieve(url, fpath)
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
+ if url[:5] == "https":
+ url = url.replace("https:", "http:")
+ print(
+ "Failed download. Trying https -> http instead."
+ " Downloading " + url + " to " + fpath
+ )
+ _urlretrieve(url, fpath)
+ else:
+ raise e
+
+ # check integrity of downloaded file
+ if not check_integrity(fpath, md5):
+ raise RuntimeError("File not found or corrupted.")
+
+
+def download_and_extract_archive(
+ url: str,
+ download_root: str,
+ extract_root: Optional[str] = None,
+ filename: Optional[str] = None,
+ md5: Optional[str] = None,
+ remove_finished: bool = False,
+) -> None:
+ download_root = os.path.expanduser(download_root)
+ if extract_root is None:
+ extract_root = download_root
+ if not filename:
+ filename = os.path.basename(url)
+
+ download_url(url, download_root, filename, md5)
+
+ archive = os.path.join(download_root, filename)
+ print("Extracting {} to {}".format(archive, extract_root))
+ extract_archive(archive, extract_root, remove_finished)
+
+
+def cache_url(url: str, cache_dir: str) -> str:
+ """
+ This implementation downloads the remote resource and caches it locally.
+ The resource will only be downloaded if not previously requested.
+ """
+ parsed_url = urlparse(url)
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
+ makedir(dirname)
+ filename = url.split("/")[-1]
+ cached = os.path.join(dirname, filename)
+ with file_lock(cached):
+ if not os.path.isfile(cached):
+ logging.info(f"Downloading {url} to {cached} ...")
+ cached = download(url, dirname, filename=filename)
+ logging.info(f"URL {url} cached in {cached}")
+ return cached
+
+
+# TODO (prigoyal): convert this into RAII-style API
+def create_file_symlink(file1, file2):
+ """
+ Simply create the symlinks for a given file1 to file2.
+ Useful during model checkpointing to symlinks to the
+ latest successful checkpoint.
+ """
+ try:
+ if g_pathmgr.exists(file2):
+ g_pathmgr.rm(file2)
+ g_pathmgr.symlink(file1, file2)
+ except Exception as e:
+ logging.info(f"Could NOT create symlink. Error: {e}")
+
+
+def save_file(data, filename, append_to_json=True, verbose=True):
+ """
+ Common i/o utility to handle saving data to various file formats.
+ Supported:
+ .pkl, .pickle, .npy, .json
+ Specifically for .json, users have the option to either append (default)
+ or rewrite by passing in Boolean value to append_to_json.
+ """
+ if verbose:
+ logging.info(f"Saving data to file: {filename}")
+ file_ext = os.path.splitext(filename)[1]
+ if file_ext in [".pkl", ".pickle"]:
+ with g_pathmgr.open(filename, "wb") as fopen:
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
+ elif file_ext == ".npy":
+ with g_pathmgr.open(filename, "wb") as fopen:
+ np.save(fopen, data)
+ elif file_ext == ".json":
+ if append_to_json:
+ with g_pathmgr.open(filename, "a") as fopen:
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
+ fopen.flush()
+ else:
+ with g_pathmgr.open(filename, "w") as fopen:
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
+ fopen.flush()
+ elif file_ext == ".yaml":
+ with g_pathmgr.open(filename, "w") as fopen:
+ dump = yaml.dump(data)
+ fopen.write(dump)
+ fopen.flush()
+ else:
+ raise Exception(f"Saving {file_ext} is not supported yet")
+
+ if verbose:
+ logging.info(f"Saved data to file: {filename}")
+
+
+def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
+ """
+ Common i/o utility to handle loading data from various file formats.
+ Supported:
+ .pkl, .pickle, .npy, .json
+ For the npy files, we support reading the files in mmap_mode.
+ If the mmap_mode of reading is not successful, we load data without the
+ mmap_mode.
+ """
+ if verbose:
+ logging.info(f"Loading data from file: {filename}")
+
+ file_ext = os.path.splitext(filename)[1]
+ if file_ext == ".txt":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = fopen.readlines()
+ elif file_ext in [".pkl", ".pickle"]:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = pickle.load(fopen, encoding="latin1")
+ elif file_ext == ".npy":
+ if mmap_mode:
+ try:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(
+ fopen,
+ allow_pickle=allow_pickle,
+ encoding="latin1",
+ mmap_mode=mmap_mode,
+ )
+ except ValueError as e:
+ logging.info(
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
+ )
+ data = np.load(
+ filename,
+ allow_pickle=allow_pickle,
+ encoding="latin1",
+ mmap_mode=mmap_mode,
+ )
+ logging.info("Successfully loaded without g_pathmgr")
+ except Exception:
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+ else:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+ elif file_ext == ".json":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = json.load(fopen)
+ elif file_ext == ".yaml":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
+ elif file_ext == ".csv":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = pd.read_csv(fopen)
+ else:
+ raise Exception(f"Reading from {file_ext} is not supported yet")
+ return data
+
+
+def abspath(resource_path: str):
+ """
+ Make a path absolute, but take into account prefixes like
+ "http://" or "manifold://"
+ """
+ regex = re.compile(r"^\w+://")
+ if regex.match(resource_path) is None:
+ return os.path.abspath(resource_path)
+ else:
+ return resource_path
+
+
+def makedir(dir_path):
+ """
+ Create the directory if it does not exist.
+ """
+ is_success = False
+ try:
+ if not g_pathmgr.exists(dir_path):
+ g_pathmgr.mkdirs(dir_path)
+ is_success = True
+ except BaseException:
+ logging.info(f"Error creating directory: {dir_path}")
+ return is_success
+
+
+def is_url(input_url):
+ """
+ Check if an input string is a url. look for http(s):// and ignoring the case
+ """
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
+ return is_url
+
+
+def cleanup_dir(dir):
+ """
+ Utility for deleting a directory. Useful for cleaning the storage space
+ that contains various training artifacts like checkpoints, data etc.
+ """
+ if os.path.exists(dir):
+ logging.info(f"Deleting directory: {dir}")
+ shutil.rmtree(dir)
+ logging.info(f"Deleted contents of directory: {dir}")
+
+
+def get_file_size(filename):
+ """
+ Given a file, get the size of file in MB
+ """
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
+ return size_in_mb
diff --git a/video_llama/configs/datasets/cc_sbu/align.yaml b/video_llama/configs/datasets/cc_sbu/align.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..180ea8ff0a3548219165e8864d4a59a951206298
--- /dev/null
+++ b/video_llama/configs/datasets/cc_sbu/align.yaml
@@ -0,0 +1,5 @@
+datasets:
+ cc_sbu_align:
+ data_type: images
+ build_info:
+ storage: /path/to/cc_sbu_align_dataset
diff --git a/video_llama/configs/datasets/cc_sbu/defaults.yaml b/video_llama/configs/datasets/cc_sbu/defaults.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..359de601d2511acd6cac34e7ee1c20a1dfe9ae04
--- /dev/null
+++ b/video_llama/configs/datasets/cc_sbu/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+ cc_sbu:
+ data_type: images
+ build_info:
+ storage: /path/to/cc_sbu_dataset/{00000..00001}.tar
diff --git a/video_llama/configs/datasets/instruct/llava_instruct.yaml b/video_llama/configs/datasets/instruct/llava_instruct.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0ec4a938e299f98f6d84104c909da210385003af
--- /dev/null
+++ b/video_llama/configs/datasets/instruct/llava_instruct.yaml
@@ -0,0 +1,6 @@
+datasets:
+ llava_instruct:
+ data_type: image
+ build_info:
+ anno_dir: /path/llava_instruct_150k.json
+ videos_dir: /path/train2014/train2014/
diff --git a/video_llama/configs/datasets/instruct/webvid_instruct.yaml b/video_llama/configs/datasets/instruct/webvid_instruct.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9619106ad8cb1000c3c40fa48672fb9247988d74
--- /dev/null
+++ b/video_llama/configs/datasets/instruct/webvid_instruct.yaml
@@ -0,0 +1,6 @@
+datasets:
+ webvid_instruct:
+ data_type: image
+ build_info:
+ anno_dir: /path/webvid_align/videochat_instruct_11k.json
+ videos_dir: /path/webvid_align/videos/
diff --git a/video_llama/configs/datasets/laion/defaults.yaml b/video_llama/configs/datasets/laion/defaults.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7dfff3ba891d96a136510f34b1dcf1b774705f4a
--- /dev/null
+++ b/video_llama/configs/datasets/laion/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+ laion:
+ data_type: images
+ build_info:
+ storage: path/laion/laion_dataset/{00000..00001}.tar
diff --git a/video_llama/configs/datasets/webvid/defaults.yaml b/video_llama/configs/datasets/webvid/defaults.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..046ae32dde61e2d79d1f519e1c8c653d8e2b5886
--- /dev/null
+++ b/video_llama/configs/datasets/webvid/defaults.yaml
@@ -0,0 +1,6 @@
+datasets:
+ webvid:
+ data_type: video
+ build_info:
+ anno_dir: path/webvid/webvid_tain_data/annotations/
+ videos_dir: path//webvid/webvid_tain_data/videos/
diff --git a/video_llama/configs/default.yaml b/video_llama/configs/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ff5a6a23fa2e3914938631b96c71fdf723dbbc10
--- /dev/null
+++ b/video_llama/configs/default.yaml
@@ -0,0 +1,5 @@
+env:
+ # For default users
+ # cache_root: "cache"
+ # For internal use with persistent storage
+ cache_root: "/export/home/.cache/minigpt4"
diff --git a/video_llama/configs/models/minigpt4.yaml b/video_llama/configs/models/minigpt4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..358c3f5f7b53251c607ea490ee262a890e531dc6
--- /dev/null
+++ b/video_llama/configs/models/minigpt4.yaml
@@ -0,0 +1,33 @@
+model:
+ arch: mini_gpt4
+
+ # vit encoder
+ image_size: 224
+ drop_path_rate: 0
+ use_grad_checkpoint: False
+ vit_precision: "fp16"
+ freeze_vit: True
+ freeze_qformer: True
+
+ # Q-Former
+ num_query_token: 32
+
+ # Vicuna
+ llama_model: "ckpt/vicuna-13b/"
+
+ # generation configs
+ prompt: ""
+
+preprocess:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 224
+ eval:
+ name: "blip2_image_eval"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+ eval:
+ name: "blip_caption"
diff --git a/video_llama/configs/models/video_llama.yaml b/video_llama/configs/models/video_llama.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..27ce07c3fbaa5a867279572c5b7d1dc31d469791
--- /dev/null
+++ b/video_llama/configs/models/video_llama.yaml
@@ -0,0 +1,36 @@
+model:
+ arch: video_llama
+
+ # vit encoder
+ image_size: 224
+ drop_path_rate: 0
+ use_grad_checkpoint: False
+ vit_precision: "fp16"
+ freeze_vit: True
+ freeze_qformer: True
+
+ # Q-Former
+ num_query_token: 32
+
+ # Vicuna
+ llama_model: "ckpt/vicuna-7b/"
+
+ # generation configs
+ prompt: ""
+
+preprocess:
+ vis_processor:
+ train:
+ name: "alpro_video_train"
+ image_size: 224
+ n_frms: 8
+ eval:
+ name: "alpro_video_eval"
+ image_size: 224
+ n_frms: 8
+ text_processor:
+ train:
+ name: "blip_caption"
+ eval:
+ name: "blip_caption"
+
\ No newline at end of file
diff --git a/video_llama/conversation/__init__.py b/video_llama/conversation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/video_llama/conversation/conversation_video.py b/video_llama/conversation/conversation_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd96a7a275f691519cd86200d7ed178d7cd2b75f
--- /dev/null
+++ b/video_llama/conversation/conversation_video.py
@@ -0,0 +1,248 @@
+"""
+Conversation prompt template of Video-LLaMA.
+Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
+"""
+import argparse
+import time
+from PIL import Image
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple, Any
+import os
+from video_llama.common.registry import registry
+from video_llama.processors.video_processor import ToTHWC,ToUint8,load_video
+from video_llama.processors import Blip2ImageEvalProcessor
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ # system_img: List[Image.Image] = []
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+
+ skip_next: bool = False
+ conv_id: Any = None
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ # system_img=self.system_img,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ conv_id=self.conv_id)
+
+ def dict(self):
+ return {
+ "system": self.system,
+ # "system_img": self.system_img,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ "conv_id": self.conv_id,
+ }
+
+
+class StoppingCriteriaSub(StoppingCriteria):
+
+ def __init__(self, stops=[], encounters=1):
+ super().__init__()
+ self.stops = stops
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ for stop in self.stops:
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
+ return True
+
+ return False
+
+
+CONV_VISION = Conversation(
+ system="Give the following image: ImageContent. "
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+default_conversation = Conversation(
+ system="",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+class Chat:
+ def __init__(self, model, vis_processor, device='cuda:0'):
+ self.device = device
+ self.model = model
+ self.vis_processor = vis_processor
+ self.image_vis_processor = Blip2ImageEvalProcessor()
+ stop_words_ids = [torch.tensor([835]).to(self.device),
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+ def ask(self, text, conv):
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
+ and ('' in conv.messages[-1][1] or '' in conv.messages[-1][1]): # last message is image.
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
+ else:
+ conv.append_message(conv.roles[0], text)
+
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+ conv.append_message(conv.roles[1], None)
+ embs = self.get_context_emb(conv, img_list)
+
+ current_max_len = embs.shape[1] + max_new_tokens
+ if current_max_len - max_length > 0:
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
+ 'The model will not see the contexts outside the range.')
+ begin_idx = max(0, current_max_len - max_length)
+
+ embs = embs[:, begin_idx:]
+
+ outputs = self.model.llama_model.generate(
+ inputs_embeds=embs,
+ max_new_tokens=max_new_tokens,
+ stopping_criteria=self.stopping_criteria,
+ num_beams=num_beams,
+ do_sample=True,
+ min_length=min_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty,
+ temperature=temperature,
+ )
+ output_token = outputs[0]
+ if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it
+ output_token = output_token[1:]
+ if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
+ output_token = output_token[1:]
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
+ output_text = output_text.split('Assistant:')[-1].strip()
+ conv.messages[-1][1] = output_text
+ return output_text, output_token.cpu().numpy()
+
+ def upload_video(self, video, conv, img_list):
+
+ msg = ""
+ if isinstance(video, str): # is a video path
+ ext = os.path.splitext(video)[-1].lower()
+ print(video)
+ # image = self.vis_processor(image).unsqueeze(0).to(self.device)
+ video, msg = load_video(
+ video_path=video,
+ n_frms=8,
+ height=224,
+ width=224,
+ sampling ="uniform", return_msg = True
+ )
+ video = self.vis_processor.transform(video)
+ video = video.unsqueeze(0).to(self.device)
+ # print(image)
+ else:
+ raise NotImplementedError
+
+ image_emb, _ = self.model.encode_img(video)
+ img_list.append(image_emb)
+ conv.append_message(conv.roles[0], " "+ msg)
+ return "Received."
+
+ def upload_img(self, image, conv, img_list):
+
+ msg = ""
+ if isinstance(image, str): # is a image path
+ raw_image = Image.open(image).convert('RGB') # 增加一个时间维度
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
+ elif isinstance(image, Image.Image):
+ raw_image = image
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
+ elif isinstance(image, torch.Tensor):
+ if len(image.shape) == 3:
+ image = image.unsqueeze(0)
+ image = image.to(self.device)
+ else:
+ raise NotImplementedError
+
+ image_emb, _ = self.model.encode_img(image)
+ img_list.append(image_emb)
+ # Todo msg=""
+ conv.append_message(conv.roles[0], " "+ msg)
+
+ return "Received."
+
+ def get_context_emb(self, conv, img_list):
+ prompt = conv.get_prompt()
+ prompt_segs = prompt.split('')
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
+ seg_tokens = [
+ self.model.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
+ # only add bos to the first seg
+ for i, seg in enumerate(prompt_segs)
+ ]
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
+
+
diff --git a/video_llama/datasets/__init__.py b/video_llama/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/video_llama/datasets/builders/__init__.py b/video_llama/datasets/builders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b160d0b8ad5793e368d8b2d26ff9829fa3ddd9a
--- /dev/null
+++ b/video_llama/datasets/builders/__init__.py
@@ -0,0 +1,77 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from video_llama.datasets.builders.base_dataset_builder import load_dataset_config
+from video_llama.datasets.builders.image_text_pair_builder import (
+ CCSBUBuilder,
+ LaionBuilder,
+ CCSBUAlignBuilder
+)
+from video_llama.datasets.builders.video_caption_builder import WebvidBuilder
+from video_llama.common.registry import registry
+from video_llama.datasets.builders.instruct_builder import WebvidInstruct_Builder,LlavaInstruct_Builder
+__all__ = [
+ "CCSBUBuilder",
+ "LaionBuilder",
+ "CCSBUAlignBuilder",
+ "WebvidBuilder",
+ "LlavaInstruct_Builder",
+ "WebvidInstruct_Builder"
+
+]
+
+
+def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
+ """
+ Example
+
+ >>> dataset = load_dataset("coco_caption", cfg=None)
+ >>> splits = dataset.keys()
+ >>> print([len(dataset[split]) for split in splits])
+
+ """
+ if cfg_path is None:
+ cfg = None
+ else:
+ cfg = load_dataset_config(cfg_path)
+
+ try:
+ builder = registry.get_builder_class(name)(cfg)
+ except TypeError:
+ print(
+ f"Dataset {name} not found. Available datasets:\n"
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
+ )
+ exit(1)
+
+ if vis_path is not None:
+ if data_type is None:
+ # use default data type in the config
+ data_type = builder.config.data_type
+
+ assert (
+ data_type in builder.config.build_info
+ ), f"Invalid data_type {data_type} for {name}."
+
+ builder.config.build_info.get(data_type).storage = vis_path
+
+ dataset = builder.build_datasets()
+ return dataset
+
+
+class DatasetZoo:
+ def __init__(self) -> None:
+ self.dataset_zoo = {
+ k: list(v.DATASET_CONFIG_DICT.keys())
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
+ }
+
+ def get_names(self):
+ return list(self.dataset_zoo.keys())
+
+
+dataset_zoo = DatasetZoo()
diff --git a/video_llama/datasets/builders/base_dataset_builder.py b/video_llama/datasets/builders/base_dataset_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..86c2cf688e9bcac67138aa32c58927e4a5ddebba
--- /dev/null
+++ b/video_llama/datasets/builders/base_dataset_builder.py
@@ -0,0 +1,236 @@
+"""
+ This file is from
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+import shutil
+import warnings
+
+from omegaconf import OmegaConf
+import torch.distributed as dist
+from torchvision.datasets.utils import download_url
+
+import video_llama.common.utils as utils
+from video_llama.common.dist_utils import is_dist_avail_and_initialized, is_main_process
+from video_llama.common.registry import registry
+from video_llama.processors.base_processor import BaseProcessor
+
+
+
+class BaseDatasetBuilder:
+ train_dataset_cls, eval_dataset_cls = None, None
+
+ def __init__(self, cfg=None):
+ super().__init__()
+
+ if cfg is None:
+ # help to create datasets from default config.
+ self.config = load_dataset_config(self.default_config_path())
+ elif isinstance(cfg, str):
+ self.config = load_dataset_config(cfg)
+ else:
+ # when called from task.build_dataset()
+ self.config = cfg
+
+ self.data_type = self.config.data_type
+
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+
+ def build_datasets(self):
+ # download, split, etc...
+ # only called on 1 GPU/TPU in distributed
+
+ if is_main_process():
+ self._download_data()
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ datasets = self.build() # dataset['train'/'val'/'test']
+
+ return datasets
+
+ def build_processors(self):
+ vis_proc_cfg = self.config.get("vis_processor")
+ txt_proc_cfg = self.config.get("text_processor")
+
+ if vis_proc_cfg is not None:
+ vis_train_cfg = vis_proc_cfg.get("train")
+ vis_eval_cfg = vis_proc_cfg.get("eval")
+
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
+
+ if txt_proc_cfg is not None:
+ txt_train_cfg = txt_proc_cfg.get("train")
+ txt_eval_cfg = txt_proc_cfg.get("eval")
+
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
+
+ @staticmethod
+ def _build_proc_from_cfg(cfg):
+ return (
+ registry.get_processor_class(cfg.name).from_config(cfg)
+ if cfg is not None
+ else None
+ )
+
+ @classmethod
+ def default_config_path(cls, type="default"):
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
+
+ def _download_data(self):
+ self._download_ann()
+ self._download_vis()
+
+ def _download_ann(self):
+ """
+ Download annotation files if necessary.
+ All the vision-language datasets should have annotations of unified format.
+
+ storage_path can be:
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
+
+ Local annotation paths should be relative.
+ """
+ anns = self.config.build_info.annotations
+
+ splits = anns.keys()
+
+ cache_root = registry.get_path("cache_root")
+
+ for split in splits:
+ info = anns[split]
+
+ urls, storage_paths = info.get("url", None), info.storage
+
+ if isinstance(urls, str):
+ urls = [urls]
+ if isinstance(storage_paths, str):
+ storage_paths = [storage_paths]
+
+ assert len(urls) == len(storage_paths)
+
+ for url_or_filename, storage_path in zip(urls, storage_paths):
+ # if storage_path is relative, make it full by prefixing with cache_root.
+ if not os.path.isabs(storage_path):
+ storage_path = os.path.join(cache_root, storage_path)
+
+ dirname = os.path.dirname(storage_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ if os.path.isfile(url_or_filename):
+ src, dst = url_or_filename, storage_path
+ if not os.path.exists(dst):
+ shutil.copyfile(src=src, dst=dst)
+ else:
+ logging.info("Using existing file {}.".format(dst))
+ else:
+ if os.path.isdir(storage_path):
+ # if only dirname is provided, suffix with basename of URL.
+ raise ValueError(
+ "Expecting storage_path to be a file path, got directory {}".format(
+ storage_path
+ )
+ )
+ else:
+ filename = os.path.basename(storage_path)
+
+ download_url(url=url_or_filename, root=dirname, filename=filename)
+
+ def _download_vis(self):
+
+ storage_path = self.config.build_info.get(self.data_type).storage
+ storage_path = utils.get_cache_path(storage_path)
+
+ if not os.path.exists(storage_path):
+ warnings.warn(
+ f"""
+ The specified path {storage_path} for visual inputs does not exist.
+ Please provide a correct path to the visual inputs or
+ refer to datasets/download_scripts/README.md for downloading instructions.
+ """
+ )
+
+ def build(self):
+ """
+ Create by split datasets inheriting torch.utils.data.Datasets.
+
+ # build() can be dataset-specific. Overwrite to customize.
+ """
+ self.build_processors()
+
+ build_info = self.config.build_info
+
+ ann_info = build_info.annotations
+ vis_info = build_info.get(self.data_type)
+
+ datasets = dict()
+ for split in ann_info.keys():
+ if split not in ["train", "val", "test"]:
+ continue
+
+ is_train = split == "train"
+
+ # processors
+ vis_processor = (
+ self.vis_processors["train"]
+ if is_train
+ else self.vis_processors["eval"]
+ )
+ text_processor = (
+ self.text_processors["train"]
+ if is_train
+ else self.text_processors["eval"]
+ )
+
+ # annotation path
+ ann_paths = ann_info.get(split).storage
+ if isinstance(ann_paths, str):
+ ann_paths = [ann_paths]
+
+ abs_ann_paths = []
+ for ann_path in ann_paths:
+ if not os.path.isabs(ann_path):
+ ann_path = utils.get_cache_path(ann_path)
+ abs_ann_paths.append(ann_path)
+ ann_paths = abs_ann_paths
+
+ # visual data storage path
+ vis_path = os.path.join(vis_info.storage, split)
+
+ if not os.path.isabs(vis_path):
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
+ vis_path = utils.get_cache_path(vis_path)
+
+ if not os.path.exists(vis_path):
+ warnings.warn("storage path {} does not exist.".format(vis_path))
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=vis_processor,
+ text_processor=text_processor,
+ ann_paths=ann_paths,
+ vis_root=vis_path,
+ )
+
+ return datasets
+
+
+def load_dataset_config(cfg_path):
+ cfg = OmegaConf.load(cfg_path).datasets
+ cfg = cfg[list(cfg.keys())[0]]
+
+ return cfg
diff --git a/video_llama/datasets/builders/image_text_pair_builder.py b/video_llama/datasets/builders/image_text_pair_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f93bf8f0dd51318c01940f07dc10e9dda2dd275
--- /dev/null
+++ b/video_llama/datasets/builders/image_text_pair_builder.py
@@ -0,0 +1,106 @@
+import os
+import logging
+import warnings
+
+from video_llama.common.registry import registry
+from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+from video_llama.datasets.datasets.laion_dataset import LaionDataset
+from video_llama.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
+
+
+@registry.register_builder("cc_sbu")
+class CCSBUBuilder(BaseDatasetBuilder):
+ train_dataset_cls = CCSBUDataset
+
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
+
+ def _download_ann(self):
+ pass
+
+ def _download_vis(self):
+ pass
+
+ def build(self):
+ self.build_processors()
+
+ build_info = self.config.build_info
+
+ datasets = dict()
+ split = "train"
+
+ # create datasets
+ # [NOTE] return inner_datasets (wds.DataPipeline)
+ dataset_cls = self.train_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=self.vis_processors[split],
+ text_processor=self.text_processors[split],
+ location=build_info.storage,
+ ).inner_dataset
+
+ return datasets
+
+
+@registry.register_builder("laion")
+class LaionBuilder(BaseDatasetBuilder):
+ train_dataset_cls = LaionDataset
+
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
+
+ def _download_ann(self):
+ pass
+
+ def _download_vis(self):
+ pass
+
+ def build(self):
+ self.build_processors()
+
+ build_info = self.config.build_info
+
+ datasets = dict()
+ split = "train"
+
+ # create datasets
+ # [NOTE] return inner_datasets (wds.DataPipeline)
+ dataset_cls = self.train_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=self.vis_processors[split],
+ text_processor=self.text_processors[split],
+ location=build_info.storage,
+ ).inner_dataset
+
+ return datasets
+
+
+@registry.register_builder("cc_sbu_align")
+class CCSBUAlignBuilder(BaseDatasetBuilder):
+ train_dataset_cls = CCSBUAlignDataset
+
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/cc_sbu/align.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+
+ build_info = self.config.build_info
+ storage_path = build_info.storage
+
+ datasets = dict()
+
+ if not os.path.exists(storage_path):
+ warnings.warn("storage path {} does not exist.".format(storage_path))
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
+ vis_root=os.path.join(storage_path, 'image'),
+ )
+
+ return datasets
+
diff --git a/video_llama/datasets/builders/instruct_builder.py b/video_llama/datasets/builders/instruct_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b95238785386af934721d65cc7859d60f57023ae
--- /dev/null
+++ b/video_llama/datasets/builders/instruct_builder.py
@@ -0,0 +1,78 @@
+import os
+import logging
+import warnings
+
+from video_llama.common.registry import registry
+from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+from video_llama.datasets.datasets.laion_dataset import LaionDataset
+from video_llama.datasets.datasets.llava_instruct_dataset import Instruct_Dataset
+from video_llama.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset
+
+@registry.register_builder("instruct")
+class Instruct_Builder(BaseDatasetBuilder):
+ train_dataset_cls = Instruct_Dataset
+
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
+
+ def _download_ann(self):
+ pass
+
+ def _download_vis(self):
+ pass
+
+ def build(self):
+ self.build_processors()
+ datasets = dict()
+ split = "train"
+
+ build_info = self.config.build_info
+ dataset_cls = self.train_dataset_cls
+ if self.config.num_video_query_token:
+ num_video_query_token = self.config.num_video_query_token
+ else:
+ num_video_query_token = 32
+
+ if self.config.tokenizer_name:
+ tokenizer_name = self.config.tokenizer_name
+ else:
+ tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
+
+
+ datasets[split] = dataset_cls(
+ vis_processor=self.vis_processors[split],
+ text_processor=self.text_processors[split],
+ vis_root=build_info.videos_dir,
+ ann_root=build_info.anno_dir,
+ num_video_query_token = num_video_query_token,
+ tokenizer_name = tokenizer_name,
+ data_type = self.config.data_type
+ )
+
+ return datasets
+
+@registry.register_builder("webvid_instruct")
+class WebvidInstruct_Builder(Instruct_Builder):
+ train_dataset_cls = Video_Instruct_Dataset
+
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/instruct/webvid_instruct.yaml",
+ }
+
+@registry.register_builder("webvid_instruct_zh")
+class WebvidInstruct_zh_Builder(Instruct_Builder):
+ train_dataset_cls = Video_Instruct_Dataset
+
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/instruct/webvid_instruct.yaml",
+ }
+
+
+
+@registry.register_builder("llava_instruct")
+class LlavaInstruct_Builder(Instruct_Builder):
+ train_dataset_cls = Instruct_Dataset
+
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/instruct/llava_instruct.yaml",
+ }
+
diff --git a/video_llama/datasets/builders/video_caption_builder.py b/video_llama/datasets/builders/video_caption_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e73ef9db75c4699e1a763b459a7c96c1147e192b
--- /dev/null
+++ b/video_llama/datasets/builders/video_caption_builder.py
@@ -0,0 +1,34 @@
+import os
+import logging
+import warnings
+
+from video_llama.common.registry import registry
+from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+from video_llama.datasets.datasets.webvid_datasets import WebvidDataset
+
+@registry.register_builder("webvid")
+class WebvidBuilder(BaseDatasetBuilder):
+ train_dataset_cls = WebvidDataset
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"}
+
+ def _download_ann(self):
+ pass
+
+ def _download_vis(self):
+ pass
+
+ def build(self):
+ self.build_processors()
+ datasets = dict()
+ split = "train"
+
+ build_info = self.config.build_info
+ dataset_cls = self.train_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=self.vis_processors[split],
+ text_processor=self.text_processors[split],
+ vis_root=build_info.videos_dir,
+ ann_root=build_info.anno_dir
+ )
+
+ return datasets
\ No newline at end of file
diff --git a/video_llama/datasets/data_utils.py b/video_llama/datasets/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fe6a567bae667f00ef0ee1d4d9075649107b471
--- /dev/null
+++ b/video_llama/datasets/data_utils.py
@@ -0,0 +1,196 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import gzip
+import logging
+import os
+import random as rnd
+import tarfile
+import zipfile
+import random
+from typing import List
+from tqdm import tqdm
+
+import decord
+from decord import VideoReader
+import webdataset as wds
+import numpy as np
+import torch
+from torch.utils.data.dataset import IterableDataset
+
+from video_llama.common.registry import registry
+from video_llama.datasets.datasets.base_dataset import ConcatDataset
+
+
+decord.bridge.set_bridge("torch")
+MAX_INT = registry.get("MAX_INT")
+
+
+class ChainDataset(wds.DataPipeline):
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
+
+ This class is useful to assemble different existing dataset streams. The
+ chaining operation is done on-the-fly, so concatenating large-scale
+ datasets with this class will be efficient.
+
+ Args:
+ datasets (iterable of IterableDataset): datasets to be chained together
+ """
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
+ super().__init__()
+ self.datasets = datasets
+ self.prob = []
+ self.names = []
+ for dataset in self.datasets:
+ if hasattr(dataset, 'name'):
+ self.names.append(dataset.name)
+ else:
+ self.names.append('Unknown')
+ if hasattr(dataset, 'sample_ratio'):
+ self.prob.append(dataset.sample_ratio)
+ else:
+ self.prob.append(1)
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
+
+ def __iter__(self):
+ datastreams = [iter(dataset) for dataset in self.datasets]
+ while True:
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
+ yield next(select_datastream)
+
+
+def apply_to_sample(f, sample):
+ if len(sample) == 0:
+ return {}
+
+ def _apply(x):
+ if torch.is_tensor(x):
+ return f(x)
+ elif isinstance(x, dict):
+ return {key: _apply(value) for key, value in x.items()}
+ elif isinstance(x, list):
+ return [_apply(x) for x in x]
+ else:
+ return x
+
+ return _apply(sample)
+
+
+def move_to_cuda(sample):
+ def _move_to_cuda(tensor):
+ return tensor.cuda()
+
+ return apply_to_sample(_move_to_cuda, sample)
+
+
+def prepare_sample(samples, cuda_enabled=True):
+ if cuda_enabled:
+ samples = move_to_cuda(samples)
+
+ # TODO fp16 support
+
+ return samples
+
+
+def reorg_datasets_by_split(datasets):
+ """
+ Organizes datasets by split.
+
+ Args:
+ datasets: dict of torch.utils.data.Dataset objects by name.
+
+ Returns:
+ Dict of datasets by split {split_name: List[Datasets]}.
+ """
+ # if len(datasets) == 1:
+ # return datasets[list(datasets.keys())[0]]
+ # else:
+ reorg_datasets = dict()
+
+ # reorganize by split
+ for _, dataset in datasets.items():
+ for split_name, dataset_split in dataset.items():
+ if split_name not in reorg_datasets:
+ reorg_datasets[split_name] = [dataset_split]
+ else:
+ reorg_datasets[split_name].append(dataset_split)
+
+ return reorg_datasets
+
+
+def concat_datasets(datasets):
+ """
+ Concatenates multiple datasets into a single dataset.
+
+ It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
+ generic IterableDataset because it requires creating separate samplers.
+
+ Now only supports conctenating training datasets and assuming validation and testing
+ have only a single dataset. This is because metrics should not be computed on the concatenated
+ datasets.
+
+ Args:
+ datasets: dict of torch.utils.data.Dataset objects by split.
+
+ Returns:
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
+ "val" and "test" remain the same.
+
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
+ a tuple, where the first element is a concatenated map-style dataset and the second
+ element is a chained DataPipeline dataset.
+
+ """
+ # concatenate datasets in the same split
+ for split_name in datasets:
+ if split_name != "train":
+ assert (
+ len(datasets[split_name]) == 1
+ ), "Do not support multiple {} datasets.".format(split_name)
+ datasets[split_name] = datasets[split_name][0]
+ else:
+ iterable_datasets, map_datasets = [], []
+ for dataset in datasets[split_name]:
+ if isinstance(dataset, wds.DataPipeline):
+ logging.info(
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
+ dataset
+ )
+ )
+ iterable_datasets.append(dataset)
+ elif isinstance(dataset, IterableDataset):
+ raise NotImplementedError(
+ "Do not support concatenation of generic IterableDataset."
+ )
+ else:
+ map_datasets.append(dataset)
+
+ # if len(iterable_datasets) > 0:
+ # concatenate map-style datasets and iterable-style datasets separately
+ if len(iterable_datasets) > 1:
+ chained_datasets = (
+ ChainDataset(iterable_datasets)
+ )
+ elif len(iterable_datasets) == 1:
+ chained_datasets = iterable_datasets[0]
+ else:
+ chained_datasets = None
+
+ concat_datasets = (
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+ )
+
+ train_datasets = concat_datasets, chained_datasets
+ train_datasets = tuple([x for x in train_datasets if x is not None])
+ train_datasets = (
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
+ )
+
+ datasets[split_name] = train_datasets
+
+ return datasets
+
diff --git a/video_llama/datasets/datasets/__init__.py b/video_llama/datasets/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/video_llama/datasets/datasets/base_dataset.py b/video_llama/datasets/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2a8d0e21370129c0182cddc427eb293bbe5982
--- /dev/null
+++ b/video_llama/datasets/datasets/base_dataset.py
@@ -0,0 +1,68 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import json
+from typing import Iterable
+
+from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data.dataloader import default_collate
+
+
+class BaseDataset(Dataset):
+ def __init__(
+ self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
+ ):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.annotation = []
+ for ann_path in ann_paths:
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self._add_instance_ids()
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def collater(self, samples):
+ return default_collate(samples)
+
+ def set_processors(self, vis_processor, text_processor):
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ def _add_instance_ids(self, key="instance_id"):
+ for idx, ann in enumerate(self.annotation):
+ ann[key] = str(idx)
+
+
+class ConcatDataset(ConcatDataset):
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
+ super().__init__(datasets)
+
+ def collater(self, samples):
+ # TODO For now only supports datasets with same underlying collater implementations
+
+ all_keys = set()
+ for s in samples:
+ all_keys.update(s)
+
+ shared_keys = all_keys
+ for s in samples:
+ shared_keys = shared_keys & set(s.keys())
+
+ samples_shared_keys = []
+ for s in samples:
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+ return self.datasets[0].collater(samples_shared_keys)
diff --git a/video_llama/datasets/datasets/caption_datasets.py b/video_llama/datasets/datasets/caption_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78105896012b87174a547a365451d5d67fd8e93
--- /dev/null
+++ b/video_llama/datasets/datasets/caption_datasets.py
@@ -0,0 +1,85 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+from collections import OrderedDict
+
+from video_llama.datasets.datasets.base_dataset import BaseDataset
+from PIL import Image
+
+
+class __DisplMixin:
+ def displ_item(self, index):
+ sample, ann = self.__getitem__(index), self.annotation[index]
+
+ return OrderedDict(
+ {
+ "file": ann["image"],
+ "caption": ann["caption"],
+ "image": sample["image"],
+ }
+ )
+
+
+class CaptionDataset(BaseDataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ self.img_ids = {}
+ n = 0
+ for ann in self.annotation:
+ img_id = ann["image_id"]
+ if img_id not in self.img_ids.keys():
+ self.img_ids[img_id] = n
+ n += 1
+
+ def __getitem__(self, index):
+
+ # TODO this assumes image input, not general enough
+ ann = self.annotation[index]
+
+ img_file = '{:0>12}.jpg'.format(ann["image_id"])
+ image_path = os.path.join(self.vis_root, img_file)
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ caption = self.text_processor(ann["caption"])
+
+ return {
+ "image": image,
+ "text_input": caption,
+ "image_id": self.img_ids[ann["image_id"]],
+ }
+
+
+class CaptionEvalDataset(BaseDataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+
+ return {
+ "image": image,
+ "image_id": ann["image_id"],
+ "instance_id": ann["instance_id"],
+ }
diff --git a/video_llama/datasets/datasets/cc_sbu_dataset.py b/video_llama/datasets/datasets/cc_sbu_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..59311479552e55b0e6f7d9aec3d70b3d993f92d1
--- /dev/null
+++ b/video_llama/datasets/datasets/cc_sbu_dataset.py
@@ -0,0 +1,49 @@
+import os
+from PIL import Image
+import webdataset as wds
+from video_llama.datasets.datasets.base_dataset import BaseDataset
+from video_llama.datasets.datasets.caption_datasets import CaptionDataset
+
+
+class CCSBUDataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, location):
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+ self.inner_dataset = wds.DataPipeline(
+ wds.ResampledShards(location),
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
+ wds.shuffle(1000, handler=wds.warn_and_continue),
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
+ )
+
+ def to_dict(self, sample):
+ return {
+ "image": sample[0],
+ "text_input": self.text_processor(sample[1]["caption"]),
+ "type":'image',
+ }
+
+
+class CCSBUAlignDataset(CaptionDataset):
+
+ def __getitem__(self, index):
+
+ # TODO this assumes image input, not general enough
+ ann = self.annotation[index]
+
+ img_file = '{}.jpg'.format(ann["image_id"])
+ image_path = os.path.join(self.vis_root, img_file)
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ caption = ann["caption"]
+
+ return {
+ "image": image,
+ "text_input": caption,
+ "image_id": self.img_ids[ann["image_id"]],
+ "type":'image',
+ }
\ No newline at end of file
diff --git a/video_llama/datasets/datasets/dataloader_utils.py b/video_llama/datasets/datasets/dataloader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2f574e24d2a32a18533a11492cfd481ff2cfbb
--- /dev/null
+++ b/video_llama/datasets/datasets/dataloader_utils.py
@@ -0,0 +1,162 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import time
+import random
+import torch
+from video_llama.datasets.data_utils import move_to_cuda
+from torch.utils.data import DataLoader
+
+
+class MultiIterLoader:
+ """
+ A simple wrapper for iterating over multiple iterators.
+
+ Args:
+ loaders (List[Loader]): List of Iterator loaders.
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
+ """
+
+ def __init__(self, loaders, ratios=None):
+ # assert all loaders has __next__ method
+ for loader in loaders:
+ assert hasattr(
+ loader, "__next__"
+ ), "Loader {} has no __next__ method.".format(loader)
+
+ if ratios is None:
+ ratios = [1.0] * len(loaders)
+ else:
+ assert len(ratios) == len(loaders)
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
+
+ self.loaders = loaders
+ self.ratios = ratios
+
+ def __next__(self):
+ # random sample from each loader by ratio
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
+ return next(self.loaders[loader_idx])
+
+
+class PrefetchLoader(object):
+ """
+ Modified from https://github.com/ChenRocks/UNITER.
+
+ overlap compute and cuda data transfer
+ (copied and then modified from nvidia apex)
+ """
+
+ def __init__(self, loader):
+ self.loader = loader
+ self.stream = torch.cuda.Stream()
+
+ def __iter__(self):
+ loader_it = iter(self.loader)
+ self.preload(loader_it)
+ batch = self.next(loader_it)
+ while batch is not None:
+ is_tuple = isinstance(batch, tuple)
+ if is_tuple:
+ task, batch = batch
+
+ if is_tuple:
+ yield task, batch
+ else:
+ yield batch
+ batch = self.next(loader_it)
+
+ def __len__(self):
+ return len(self.loader)
+
+ def preload(self, it):
+ try:
+ self.batch = next(it)
+ except StopIteration:
+ self.batch = None
+ return
+ # if record_stream() doesn't work, another option is to make sure
+ # device inputs are created on the main stream.
+ # self.next_input_gpu = torch.empty_like(self.next_input,
+ # device='cuda')
+ # self.next_target_gpu = torch.empty_like(self.next_target,
+ # device='cuda')
+ # Need to make sure the memory allocated for next_* is not still in use
+ # by the main stream at the time we start copying to next_*:
+ # self.stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.stream):
+ self.batch = move_to_cuda(self.batch)
+ # more code for the alternative if record_stream() doesn't work:
+ # copy_ will record the use of the pinned source tensor in this
+ # side stream.
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+ # self.next_input = self.next_input_gpu
+ # self.next_target = self.next_target_gpu
+
+ def next(self, it):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is not None:
+ record_cuda_stream(batch)
+ self.preload(it)
+ return batch
+
+ def __getattr__(self, name):
+ method = self.loader.__getattribute__(name)
+ return method
+
+
+def record_cuda_stream(batch):
+ if isinstance(batch, torch.Tensor):
+ batch.record_stream(torch.cuda.current_stream())
+ elif isinstance(batch, list) or isinstance(batch, tuple):
+ for t in batch:
+ record_cuda_stream(t)
+ elif isinstance(batch, dict):
+ for t in batch.values():
+ record_cuda_stream(t)
+ else:
+ pass
+
+
+class IterLoader:
+ """
+ A wrapper to convert DataLoader as an infinite iterator.
+
+ Modified from:
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
+ """
+
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
+ self._dataloader = dataloader
+ self.iter_loader = iter(self._dataloader)
+ self._use_distributed = use_distributed
+ self._epoch = 0
+
+ @property
+ def epoch(self) -> int:
+ return self._epoch
+
+ def __next__(self):
+ try:
+ data = next(self.iter_loader)
+ except StopIteration:
+ self._epoch += 1
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
+ self._dataloader.sampler.set_epoch(self._epoch)
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.iter_loader = iter(self._dataloader)
+ data = next(self.iter_loader)
+
+ return data
+
+ def __iter__(self):
+ return self
+
+ def __len__(self):
+ return len(self._dataloader)
diff --git a/video_llama/datasets/datasets/laion_dataset.py b/video_llama/datasets/datasets/laion_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be30abb188e1afad6fe678ccbb367931a2b3d26
--- /dev/null
+++ b/video_llama/datasets/datasets/laion_dataset.py
@@ -0,0 +1,31 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import webdataset as wds
+from video_llama.datasets.datasets.base_dataset import BaseDataset
+
+
+class LaionDataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, location):
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+ self.inner_dataset = wds.DataPipeline(
+ wds.ResampledShards(location),
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
+ wds.shuffle(1000, handler=wds.warn_and_continue),
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
+ )
+
+ def to_dict(self, sample):
+ return {
+ "image": sample[0],
+ "text_input": self.text_processor(sample[1]["caption"]),
+ }
+
diff --git a/video_llama/datasets/datasets/llava_instruct_dataset.py b/video_llama/datasets/datasets/llava_instruct_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..105e0981581b7934c5df2bc53ecf03142cc4c969
--- /dev/null
+++ b/video_llama/datasets/datasets/llava_instruct_dataset.py
@@ -0,0 +1,228 @@
+import os
+from video_llama.datasets.datasets.base_dataset import BaseDataset
+from video_llama.datasets.datasets.caption_datasets import CaptionDataset
+import pandas as pd
+import decord
+from decord import VideoReader
+import random
+import torch
+from torch.utils.data.dataloader import default_collate
+from PIL import Image
+from typing import Dict, Optional, Sequence
+import transformers
+import pathlib
+import json
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
+from video_llama.conversation.conversation_video import Conversation,SeparatorStyle
+DEFAULT_IMAGE_PATCH_TOKEN = ''
+DEFAULT_IMAGE_TOKEN = ""
+import copy
+IGNORE_INDEX = -100
+image_conversation = Conversation(
+ system="",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+IGNORE_INDEX = -100
+
+class Instruct_Dataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'image'):
+ """
+ vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
+ split (string): val or test
+ """
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+ data_path = pathlib.Path(ann_root)
+ with data_path.open(encoding='utf-8') as f:
+ self.annotation = json.load(f)
+
+ self.vis_root = vis_root
+ self.resize_size = 224
+ self.num_frm = 8
+ self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.num_video_query_token = num_video_query_token
+ self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
+
+ self.transform = AlproVideoTrainProcessor(
+ image_size=self.resize_size, n_frms = self.num_frm
+ ).transform
+ self.data_type = data_type
+
+ def _get_image_path(self, sample):
+ rel_video_fp ='COCO_train2014_' + sample['image']
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
+ return full_video_fp
+
+ def __getitem__(self, index):
+ num_retries = 10 # skip error videos
+ for _ in range(num_retries):
+ try:
+ sample = self.annotation[index]
+
+ image_path = self._get_image_path(sample)
+ conversation_list = sample['conversations']
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ # text = self.text_processor(text)
+ sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token)
+ data_dict = preprocess(
+ sources,
+ self.tokenizer)
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ # image exist in the data
+ data_dict['image'] = image
+ except:
+ print(f"Failed to load examples with image: {image_path}. "
+ f"Will randomly sample an example as a replacement.")
+ index = random.randint(0, len(self) - 1)
+ continue
+ break
+ else:
+ raise RuntimeError(f"Failed to fetch image after {num_retries} retries.")
+ # "image_id" is kept to stay compatible with the COCO evaluation format
+ return {
+ "image": image,
+ "text_input": data_dict["input_ids"],
+ "labels": data_dict["labels"],
+ "type":'image',
+ }
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def collater(self, instances):
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("text_input", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'image' in instances[0]:
+ images = [instance['image'] for instance in instances]
+ if all(x is not None and x.shape == images[0].shape for x in images):
+ batch['images'] = torch.stack(images)
+ else:
+ batch['images'] = images
+ batch['conv_type'] = 'multi'
+ return batch
+
+
+def preprocess_multimodal(
+ conversation_list: Sequence[str],
+ multimodal_cfg: dict,
+ cur_token_len: int,
+) -> Dict:
+ # 将conversational list中
+ is_multimodal = True
+ # image_token_len = multimodal_cfg['image_token_len']
+ image_token_len = cur_token_len
+
+ for sentence in conversation_list:
+ replace_token = ''+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len+'/'
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return [conversation_list]
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "###"
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = image_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = image_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=512,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{image_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # tokenize conversations
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
+ tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
diff --git a/video_llama/datasets/datasets/video_instruct_dataset.py b/video_llama/datasets/datasets/video_instruct_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7de6e20d30d9b0d7280d706636e9849b7f02618c
--- /dev/null
+++ b/video_llama/datasets/datasets/video_instruct_dataset.py
@@ -0,0 +1,253 @@
+import os
+from video_llama.datasets.datasets.base_dataset import BaseDataset
+from video_llama.datasets.datasets.caption_datasets import CaptionDataset
+import pandas as pd
+import decord
+from decord import VideoReader
+import random
+import torch
+from torch.utils.data.dataloader import default_collate
+from PIL import Image
+from typing import Dict, Optional, Sequence
+import transformers
+import pathlib
+import json
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
+import copy
+from video_llama.processors import transforms_video,AlproVideoTrainProcessor
+from torchvision import transforms
+from video_llama.processors.video_processor import ToTHWC,ToUint8,load_video
+from video_llama.conversation.conversation_video import Conversation,SeparatorStyle
+
+DEFAULT_IMAGE_PATCH_TOKEN = ''
+video_conversation = Conversation(
+ system="",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+IGNORE_INDEX = -100
+
+class Video_Instruct_Dataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'video'):
+ """
+ vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
+ split (string): val or test
+ """
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+ data_path = pathlib.Path(ann_root)
+ with data_path.open(encoding='utf-8') as f:
+ self.annotation = json.load(f)
+
+ self.num_video_query_token = num_video_query_token
+ self.vis_root = vis_root
+ self.resize_size = 224
+ self.num_frm = 8
+ self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
+
+ self.transform = AlproVideoTrainProcessor(
+ image_size=self.resize_size, n_frms = self.num_frm
+ ).transform
+ self.data_type = data_type
+
+ def _get_video_path(self, sample):
+ rel_video_fp = sample['video']
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
+ return full_video_fp
+
+ def __getitem__(self, index):
+ num_retries = 10 # skip error videos
+ for _ in range(num_retries):
+ try:
+ sample = self.annotation[index]
+
+ video_path = self._get_video_path(sample)
+ conversation_list = sample['QA']
+
+ video, msg = load_video(
+ video_path=video_path,
+ n_frms=self.num_frm,
+ height=self.resize_size,
+ width=self.resize_size,
+ sampling ="uniform", return_msg = True
+ )
+ video = self.transform(video)
+ if 'cn' in self.data_type:
+ msg = ""
+ # 添加视频,以及msg到convsation list 0
+ sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token,msg = msg)
+ new_sources = convert_source_vicuna_format(sources)
+
+ data_dict = preprocess(
+ new_sources,
+ self.tokenizer)
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+ # image exist in the data
+ data_dict['image'] = video
+ except:
+ print(f"Failed to load examples with video: {video_path}. "
+ f"Will randomly sample an example as a replacement.")
+ index = random.randint(0, len(self) - 1)
+ continue
+ break
+ else:
+ raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
+ # "image_id" is kept to stay compatible with the COCO evaluation format
+ return {
+ "image": video,
+ "text_input": data_dict["input_ids"],
+ "labels": data_dict["labels"],
+ "type":'video',
+ }
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def collater(self, instances):
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("text_input", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'image' in instances[0]:
+ images = [instance['image'] for instance in instances]
+ if all(x is not None and x.shape == images[0].shape for x in images):
+ batch['images'] = torch.stack(images)
+ else:
+ batch['images'] = images
+ batch['conv_type'] = 'multi'
+ return batch
+
+def convert_source_vicuna_format(sources):
+ new_sources = []
+ for source in sources:
+ new_source = []
+ for i, sentence in enumerate(source):
+ role_0_msg = sentence['q']
+ role_1_msg = sentence['a']
+ new_source.append({
+ 'from':'human',
+ 'value': role_0_msg,
+ })
+ new_source.append({
+ 'from':'gpt',
+ 'value': role_1_msg,
+ })
+ new_sources.append(new_source)
+ return new_sources
+
+def preprocess_multimodal(
+ conversation_list: Sequence[str],
+ multimodal_cfg: dict,
+ cur_token_len: int,
+ msg=''
+) -> Dict:
+ # 将conversational list中
+ is_multimodal = True
+ # image_token_len = multimodal_cfg['image_token_len']
+ image_token_len = cur_token_len
+ conversation_list[0]["q"] = " " + msg + conversation_list[0]["q"]
+ return [conversation_list]
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "###"
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = video_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = video_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=512,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{video_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # tokenize conversations
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
+ tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
diff --git a/video_llama/datasets/datasets/webvid_datasets.py b/video_llama/datasets/datasets/webvid_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaf6b9d6dff0d96b04d40a40c0051527f7d01842
--- /dev/null
+++ b/video_llama/datasets/datasets/webvid_datasets.py
@@ -0,0 +1,122 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+from video_llama.datasets.datasets.base_dataset import BaseDataset
+from video_llama.datasets.datasets.caption_datasets import CaptionDataset
+import pandas as pd
+import decord
+from decord import VideoReader
+import random
+import torch
+from torch.utils.data.dataloader import default_collate
+class WebvidDataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root):
+ """
+ vis_root (string): Root directory of video (e.g. webvid_eval/video/)
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
+ split (string): val or test
+ """
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+
+ # 读取一个路径下所有的
+
+ ts_df = []
+ for file_name in os.listdir(ann_root):
+ if file_name.endswith('.csv'):
+ df = pd.read_csv(os.path.join(ann_root, file_name))
+ ts_df.append(df)
+
+ merged_df = pd.concat(ts_df)
+ self.annotation = merged_df
+ self.vis_root = vis_root
+ self.resize_size = 224
+ self.num_frm = 8
+ self.frm_sampling_strategy = 'headtail'
+
+ def _get_video_path(self, sample):
+ rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
+ return full_video_fp
+
+ def __getitem__(self, index):
+ num_retries = 10 # skip error videos
+ for _ in range(num_retries):
+ sample = self.annotation.iloc[index]
+ sample_dict = sample.to_dict()
+ video_id = sample_dict['videoid']
+
+ if 'name' in sample_dict.keys():
+ text = sample_dict['name'].strip()
+ else:
+ raise NotImplementedError("Un-supported text annotation format.")
+
+ # fetch video
+ video_path = self._get_video_path(sample_dict)
+ # if os.path.exists(video_path):
+ try:
+ video = self.vis_processor(video_path)
+ except:
+ print(f"Failed to load examples with video: {video_path}. "
+ f"Will randomly sample an example as a replacement.")
+ index = random.randint(0, len(self) - 1)
+ continue
+ caption = self.text_processor(text)
+
+ # print(video.size())
+ if video is None or caption is None \
+ or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]):
+ print(f"Failed to load examples with video: {video_path}. "
+ f"Will randomly sample an example as a replacement.")
+ index = random.randint(0, len(self) - 1)
+ continue
+ else:
+ break
+ else:
+ raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
+ # "image_id" is kept to stay compatible with the COCO evaluation format
+ return {
+ "image": video,
+ "text_input": caption,
+ "type":'video',
+ }
+
+ def __len__(self):
+ return len(self.annotation)
+
+ # def collater(self, samples):
+ # new_result = {}
+ # new_result['image'] = default_collate( [sample["image"] for sample in samples])
+ # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples])
+ # return new_result
+
+class WebvidDatasetEvalDataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ vname = ann["video"]
+ video_path = os.path.join(self.vis_root, vname)
+
+ video = self.vis_processor(video_path)
+
+ return {
+ "video": video,
+ "image_id": ann["image_id"],
+ "instance_id": ann["instance_id"],
+ }
+
+
diff --git a/video_llama/models/Qformer.py b/video_llama/models/Qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4902165ec6574d89f04cbeb2141b018278324ca6
--- /dev/null
+++ b/video_llama/models/Qformer.py
@@ -0,0 +1,1217 @@
+"""
+Adapted from salesforce@LAVIS. Below is the original copyright:
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+ )
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size
+ )
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
+ )
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ query_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ seq_length = input_ids.size()[1]
+ else:
+ seq_length = 0
+
+ if position_ids is None:
+ position_ids = self.position_ids[
+ :, past_key_values_length : seq_length + past_key_values_length
+ ].clone()
+
+ if input_ids is not None:
+ embeddings = self.word_embeddings(input_ids)
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = embeddings + position_embeddings
+
+ if query_embeds is not None:
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
+ else:
+ embeddings = query_embeds
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+ config, "embedding_size"
+ ):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
+ )
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ mixed_query_layer = self.query(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(-1, 1)
+ position_ids_r = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(
+ distance + self.max_position_embeddings - 1
+ )
+ positional_embedding = positional_embedding.to(
+ dtype=query_layer.dtype
+ ) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ relative_position_scores_key = torch.einsum(
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
+ )
+ attention_scores = (
+ attention_scores
+ + relative_position_scores_query
+ + relative_position_scores_key
+ )
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
+ )
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.self.num_attention_heads,
+ self.self.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = (
+ self.self.attention_head_size * self.self.num_attention_heads
+ )
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[
+ 1:
+ ] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if (
+ self.config.add_cross_attention
+ and layer_num % self.config.cross_attention_freq == 0
+ ):
+ self.crossattention = BertAttention(
+ config, is_cross_attention=self.config.add_cross_attention
+ )
+ self.has_cross_attention = True
+ else:
+ self.has_cross_attention = False
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ self.intermediate_query = BertIntermediate(config)
+ self.output_query = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ query_length=0,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = (
+ past_key_value[:2] if past_key_value is not None else None
+ )
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:-1]
+
+ present_key_value = self_attention_outputs[-1]
+
+ if query_length > 0:
+ query_attention_output = attention_output[:, :query_length, :]
+
+ if self.has_cross_attention:
+ assert (
+ encoder_hidden_states is not None
+ ), "encoder_hidden_states must be given for cross-attention layers"
+ cross_attention_outputs = self.crossattention(
+ query_attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ query_attention_output = cross_attention_outputs[0]
+ outputs = (
+ outputs + cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk_query,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ query_attention_output,
+ )
+ if attention_output.shape[1] > query_length:
+ layer_output_text = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output[:, query_length:, :],
+ )
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+ else:
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+ def feed_forward_chunk_query(self, attention_output):
+ intermediate_output = self.intermediate_query(attention_output)
+ layer_output = self.output_query(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ query_length=0,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = (
+ () if output_attentions and self.config.add_cross_attention else None
+ )
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ if use_cache:
+ logger.warn(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(
+ *inputs, past_key_value, output_attentions, query_length
+ )
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ query_length,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=False):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: Tensor,
+ input_shape: Tuple[int],
+ device: device,
+ is_decoder: bool,
+ has_query: bool = False,
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = (
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+ <= seq_ids[None, :, None]
+ )
+
+ # add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ if has_query: # UniLM style attention mask
+ causal_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, prefix_seq_len, seq_length),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=1,
+ )
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+ extended_attention_mask = (
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is None:
+ assert (
+ query_embeds is not None
+ ), "You have to specify query_embeds when input_ids is None"
+
+ # past_key_values_length
+ past_key_values_length = (
+ past_key_values[0][0].shape[2] - self.config.query_length
+ if past_key_values is not None
+ else 0
+ )
+
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ query_embeds=query_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ input_shape = embedding_output.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = embedding_output.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length + past_key_values_length)), device=device
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if is_decoder:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask,
+ input_ids.shape,
+ device,
+ is_decoder,
+ has_query=(query_embeds is not None),
+ )
+ else:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
+ 0
+ ].size()
+ else:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
+ ]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ query_length=query_length,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = (
+ self.pooler(sequence_output) if self.pooler is not None else None
+ )
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=True,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction="mean",
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ if labels is not None:
+ use_cache = False
+ if past_key_values is not None:
+ query_embeds = None
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ sequence_output = outputs[0]
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+ if reduction == "none":
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
+ ):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "query_embeds": query_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx) for past_state in layer_past
+ ),
+ )
+ return reordered_past
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return (
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ )
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/video_llama/models/__init__.py b/video_llama/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ddc54288cdc9d7e75b71da3f8597052f4f4837c
--- /dev/null
+++ b/video_llama/models/__init__.py
@@ -0,0 +1,201 @@
+"""
+Adapted from salesforce@LAVIS Vision-CAIR@MiniGPT-4. Below is the original copyright:
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import torch
+from omegaconf import OmegaConf
+
+from video_llama.common.registry import registry
+from video_llama.models.base_model import BaseModel
+from video_llama.models.blip2 import Blip2Base
+from video_llama.models.video_llama import VideoLLAMA
+from video_llama.processors.base_processor import BaseProcessor
+
+
+__all__ = [
+ "load_model",
+ "BaseModel",
+ "Blip2Base",
+ "VideoLLAMA"
+]
+
+
+def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
+ """
+ Load supported models.
+
+ To list all available models and types in registry:
+ >>> from video_llama.models import model_zoo
+ >>> print(model_zoo)
+
+ Args:
+ name (str): name of the model.
+ model_type (str): type of the model.
+ is_eval (bool): whether the model is in eval mode. Default: False.
+ device (str): device to use. Default: "cpu".
+ checkpoint (str): path or to checkpoint. Default: None.
+ Note that expecting the checkpoint to have the same keys in state_dict as the model.
+
+ Returns:
+ model (torch.nn.Module): model.
+ """
+
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
+
+ if checkpoint is not None:
+ model.load_checkpoint(checkpoint)
+
+ if is_eval:
+ model.eval()
+
+ if device == "cpu":
+ model = model.float()
+
+ return model.to(device)
+
+
+def load_preprocess(config):
+ """
+ Load preprocessor configs and construct preprocessors.
+
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
+
+ Args:
+ config (dict): preprocessor configs.
+
+ Returns:
+ vis_processors (dict): preprocessors for visual inputs.
+ txt_processors (dict): preprocessors for text inputs.
+
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
+ """
+
+ def _build_proc_from_cfg(cfg):
+ return (
+ registry.get_processor_class(cfg.name).from_config(cfg)
+ if cfg is not None
+ else BaseProcessor()
+ )
+
+ vis_processors = dict()
+ txt_processors = dict()
+
+ vis_proc_cfg = config.get("vis_processor")
+ txt_proc_cfg = config.get("text_processor")
+
+ if vis_proc_cfg is not None:
+ vis_train_cfg = vis_proc_cfg.get("train")
+ vis_eval_cfg = vis_proc_cfg.get("eval")
+ else:
+ vis_train_cfg = None
+ vis_eval_cfg = None
+
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
+
+ if txt_proc_cfg is not None:
+ txt_train_cfg = txt_proc_cfg.get("train")
+ txt_eval_cfg = txt_proc_cfg.get("eval")
+ else:
+ txt_train_cfg = None
+ txt_eval_cfg = None
+
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
+
+ return vis_processors, txt_processors
+
+
+def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
+ """
+ Load model and its related preprocessors.
+
+ List all available models and types in registry:
+ >>> from video_llama.models import model_zoo
+ >>> print(model_zoo)
+
+ Args:
+ name (str): name of the model.
+ model_type (str): type of the model.
+ is_eval (bool): whether the model is in eval mode. Default: False.
+ device (str): device to use. Default: "cpu".
+
+ Returns:
+ model (torch.nn.Module): model.
+ vis_processors (dict): preprocessors for visual inputs.
+ txt_processors (dict): preprocessors for text inputs.
+ """
+ model_cls = registry.get_model_class(name)
+
+ # load model
+ model = model_cls.from_pretrained(model_type=model_type)
+
+ if is_eval:
+ model.eval()
+
+ # load preprocess
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
+ if cfg is not None:
+ preprocess_cfg = cfg.preprocess
+
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+ else:
+ vis_processors, txt_processors = None, None
+ logging.info(
+ f"""No default preprocess for model {name} ({model_type}).
+ This can happen if the model is not finetuned on downstream datasets,
+ or it is not intended for direct use without finetuning.
+ """
+ )
+
+ if device == "cpu" or device == torch.device("cpu"):
+ model = model.float()
+
+ return model.to(device), vis_processors, txt_processors
+
+
+class ModelZoo:
+ """
+ A utility class to create string representation of available model architectures and types.
+
+ >>> from video_llama.models import model_zoo
+ >>> # list all available models
+ >>> print(model_zoo)
+ >>> # show total number of models
+ >>> print(len(model_zoo))
+ """
+
+ def __init__(self) -> None:
+ self.model_zoo = {
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
+ for k, v in registry.mapping["model_name_mapping"].items()
+ }
+
+ def __str__(self) -> str:
+ return (
+ "=" * 50
+ + "\n"
+ + f"{'Architectures':<30} {'Types'}\n"
+ + "=" * 50
+ + "\n"
+ + "\n".join(
+ [
+ f"{name:<30} {', '.join(types)}"
+ for name, types in self.model_zoo.items()
+ ]
+ )
+ )
+
+ def __iter__(self):
+ return iter(self.model_zoo.items())
+
+ def __len__(self):
+ return sum([len(v) for v in self.model_zoo.values()])
+
+
+model_zoo = ModelZoo()
diff --git a/video_llama/models/base_model.py b/video_llama/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..272ddd15129a83b6a5a0063553f512faca1f5612
--- /dev/null
+++ b/video_llama/models/base_model.py
@@ -0,0 +1,248 @@
+"""
+Adapted from salesforce@LAVIS. Below is the original copyright:
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from video_llama.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
+from video_llama.common.utils import get_abs_path, is_url
+from omegaconf import OmegaConf
+
+
+class BaseModel(nn.Module):
+ """Base class for models."""
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def device(self):
+ return list(self.parameters())[0].device
+
+ def load_checkpoint(self, url_or_filename):
+ """
+ Load from a finetuned checkpoint.
+
+ This should expect no mismatch in the model keys and the checkpoint keys.
+ """
+
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ if "model" in checkpoint.keys():
+ state_dict = checkpoint["model"]
+ else:
+ state_dict = checkpoint
+
+ msg = self.load_state_dict(state_dict, strict=False)
+
+ logging.info("Missing keys {}".format(msg.missing_keys))
+ logging.info("load checkpoint from %s" % url_or_filename)
+
+ return msg
+
+ @classmethod
+ def from_pretrained(cls, model_type):
+ """
+ Build a pretrained model from default configuration file, specified by model_type.
+
+ Args:
+ - model_type (str): model type, specifying architecture and checkpoints.
+
+ Returns:
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
+ """
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
+ model = cls.from_config(model_cfg)
+
+ return model
+
+ @classmethod
+ def default_config_path(cls, model_type):
+ assert (
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
+ ), "Unknown model type {}".format(model_type)
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
+
+ def load_checkpoint_from_config(self, cfg, **kwargs):
+ """
+ Load checkpoint as specified in the config file.
+
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
+ When loading the pretrained model, each task-specific architecture may define their
+ own load_from_pretrained() method.
+ """
+ load_finetuned = cfg.get("load_finetuned", True)
+ if load_finetuned:
+ finetune_path = cfg.get("finetuned", None)
+ assert (
+ finetune_path is not None
+ ), "Found load_finetuned is True, but finetune_path is None."
+ self.load_checkpoint(url_or_filename=finetune_path)
+ else:
+ # load pre-trained weights
+ pretrain_path = cfg.get("pretrained", None)
+ assert "Found load_finetuned is False, but pretrain_path is None."
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
+
+ def before_evaluation(self, **kwargs):
+ pass
+
+ def show_n_params(self, return_str=True):
+ tot = 0
+ for p in self.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return "{:.1f}M".format(tot / 1e6)
+ else:
+ return "{:.1f}K".format(tot / 1e3)
+ else:
+ return tot
+
+
+class BaseEncoder(nn.Module):
+ """
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward_features(self, samples, **kwargs):
+ raise NotImplementedError
+
+ @property
+ def device(self):
+ return list(self.parameters())[0].device
+
+
+class SharedQueueMixin:
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
+ # gather keys before updating queue
+ image_feats = concat_all_gather(image_feat)
+ text_feats = concat_all_gather(text_feat)
+
+ batch_size = image_feats.shape[0]
+
+ ptr = int(self.queue_ptr)
+ assert self.queue_size % batch_size == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+ self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
+ self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
+
+ if idxs is not None:
+ idxs = concat_all_gather(idxs)
+ self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
+
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
+ self.queue_ptr[0] = ptr
+
+
+class MomentumDistilationMixin:
+ @torch.no_grad()
+ def copy_params(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(
+ model_pair[0].parameters(), model_pair[1].parameters()
+ ):
+ param_m.data.copy_(param.data) # initialize
+ param_m.requires_grad = False # not update by gradient
+
+ @torch.no_grad()
+ def _momentum_update(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(
+ model_pair[0].parameters(), model_pair[1].parameters()
+ ):
+ param_m.data = param_m.data * self.momentum + param.data * (
+ 1.0 - self.momentum
+ )
+
+
+class GatherLayer(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ torch.distributed.all_reduce(all_gradients)
+ return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+ """
+ Performs all_gather operation on the provided tensors.
+ Graph remains connected for backward grad computation.
+ """
+ # Queue the gathered tensors
+ world_size = torch.distributed.get_world_size()
+ # There is no need for reduction in the single-proc case
+ if world_size == 1:
+ return tensors
+
+ # tensor_all = GatherLayer.apply(tensors)
+ tensor_all = GatherLayer.apply(tensors)
+
+ return torch.cat(tensor_all, dim=0)
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ # if use distributed training
+ if not is_dist_avail_and_initialized():
+ return tensor
+
+ tensors_gather = [
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+def tile(x, dim, n_tile):
+ init_dim = x.size(dim)
+ repeat_idx = [1] * x.dim()
+ repeat_idx[dim] = n_tile
+ x = x.repeat(*(repeat_idx))
+ order_index = torch.LongTensor(
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
+ )
+ return torch.index_select(x, dim, order_index.to(x.device))
diff --git a/video_llama/models/blip2.py b/video_llama/models/blip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6a42b5ee1016acf4ebeb82d2c04fe69fe2364f5
--- /dev/null
+++ b/video_llama/models/blip2.py
@@ -0,0 +1,222 @@
+"""
+Adapted from salesforce@LAVIS. Below is the original copyright:
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+import contextlib
+import logging
+import os
+import time
+import datetime
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.nn.functional as F
+
+import video_llama.common.dist_utils as dist_utils
+from video_llama.common.dist_utils import download_cached_file
+from video_llama.common.utils import is_url
+from video_llama.common.logger import MetricLogger
+from video_llama.models.base_model import BaseModel
+from video_llama.models.Qformer import BertConfig, BertLMHeadModel
+from video_llama.models.eva_vit import create_eva_vit_g
+from transformers import BertTokenizer
+
+
+class Blip2Base(BaseModel):
+ @classmethod
+ def init_tokenizer(cls):
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+ return tokenizer
+
+ def maybe_autocast(self, dtype=torch.float16):
+ # if on cpu, don't use autocast
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+ enable_autocast = self.device != torch.device("cpu")
+
+ if enable_autocast:
+ return torch.cuda.amp.autocast(dtype=dtype)
+ else:
+ return contextlib.nullcontext()
+
+ @classmethod
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+ encoder_config.encoder_width = vision_width
+ # insert cross-attention layer every other block
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = cross_attention_freq
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel(config=encoder_config)
+ query_tokens = nn.Parameter(
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
+ )
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+ return Qformer, query_tokens
+
+ @classmethod
+ def init_vision_encoder(
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+ ):
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
+ visual_encoder = create_eva_vit_g(
+ img_size, drop_path_rate, use_grad_checkpoint, precision
+ )
+
+ ln_vision = LayerNorm(visual_encoder.num_features)
+ return visual_encoder, ln_vision
+
+ def load_from_pretrained(self, url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+
+ msg = self.load_state_dict(state_dict, strict=False)
+
+ # logging.info("Missing keys {}".format(msg.missing_keys))
+ logging.info("load checkpoint from %s" % url_or_filename)
+
+ return msg
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+def compute_sim_matrix(model, data_loader, **kwargs):
+ k_test = kwargs.pop("k_test")
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+
+ logging.info("Computing features for evaluation...")
+ start_time = time.time()
+
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_ids = []
+ text_embeds = []
+ text_atts = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i : min(num_text, i + text_bs)]
+ text_input = model.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=35,
+ return_tensors="pt",
+ ).to(model.device)
+ text_feat = model.forward_text(text_input)
+ text_embed = F.normalize(model.text_proj(text_feat))
+ text_embeds.append(text_embed)
+ text_ids.append(text_input.input_ids)
+ text_atts.append(text_input.attention_mask)
+
+ text_embeds = torch.cat(text_embeds, dim=0)
+ text_ids = torch.cat(text_ids, dim=0)
+ text_atts = torch.cat(text_atts, dim=0)
+
+ vit_feats = []
+ image_embeds = []
+ for samples in data_loader:
+ image = samples["image"]
+
+ image = image.to(model.device)
+ image_feat, vit_feat = model.forward_image(image)
+ image_embed = model.vision_proj(image_feat)
+ image_embed = F.normalize(image_embed, dim=-1)
+
+ vit_feats.append(vit_feat.cpu())
+ image_embeds.append(image_embed)
+
+ vit_feats = torch.cat(vit_feats, dim=0)
+ image_embeds = torch.cat(image_embeds, dim=0)
+
+ sims_matrix = []
+ for image_embed in image_embeds:
+ sim_q2t = image_embed @ text_embeds.t()
+ sim_i2t, _ = sim_q2t.max(0)
+ sims_matrix.append(sim_i2t)
+ sims_matrix = torch.stack(sims_matrix, dim=0)
+
+ score_matrix_i2t = torch.full(
+ (len(data_loader.dataset.image), len(texts)), -100.0
+ ).to(model.device)
+
+ num_tasks = dist_utils.get_world_size()
+ rank = dist_utils.get_rank()
+ step = sims_matrix.size(0) // num_tasks + 1
+ start = rank * step
+ end = min(sims_matrix.size(0), start + step)
+
+ for i, sims in enumerate(
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
+ ):
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
+ score = model.compute_itm(
+ image_inputs=image_inputs,
+ text_ids=text_ids[topk_idx],
+ text_atts=text_atts[topk_idx],
+ ).float()
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
+
+ sims_matrix = sims_matrix.t()
+ score_matrix_t2i = torch.full(
+ (len(texts), len(data_loader.dataset.image)), -100.0
+ ).to(model.device)
+
+ step = sims_matrix.size(0) // num_tasks + 1
+ start = rank * step
+ end = min(sims_matrix.size(0), start + step)
+
+ for i, sims in enumerate(
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
+ ):
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
+ score = model.compute_itm(
+ image_inputs=image_inputs,
+ text_ids=text_ids[start + i].repeat(k_test, 1),
+ text_atts=text_atts[start + i].repeat(k_test, 1),
+ ).float()
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
+
+ if dist_utils.is_dist_avail_and_initialized():
+ dist.barrier()
+ torch.distributed.all_reduce(
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
+ )
+ torch.distributed.all_reduce(
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
+ )
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logging.info("Evaluation time {}".format(total_time_str))
+
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
diff --git a/video_llama/models/blip2_outputs.py b/video_llama/models/blip2_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d83a0556e6c5c3c0a603279f318605ae25d6d5
--- /dev/null
+++ b/video_llama/models/blip2_outputs.py
@@ -0,0 +1,111 @@
+"""
+Adapted from salesforce@LAVIS. Below is the original copyright:
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from transformers.modeling_outputs import (
+ ModelOutput,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+)
+
+
+@dataclass
+class BlipSimilarity(ModelOutput):
+ sim_i2t: torch.FloatTensor = None
+ sim_t2i: torch.FloatTensor = None
+
+ sim_i2t_m: Optional[torch.FloatTensor] = None
+ sim_t2i_m: Optional[torch.FloatTensor] = None
+
+ sim_i2t_targets: Optional[torch.FloatTensor] = None
+ sim_t2i_targets: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class BlipIntermediateOutput(ModelOutput):
+ """
+ Data class for intermediate outputs of BLIP models.
+
+ image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
+ text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
+
+ image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
+ text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
+
+ encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
+ encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
+
+ decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
+ decoder_labels (torch.LongTensor): labels for the captioning loss.
+
+ itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
+ itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
+
+ """
+
+ # uni-modal features
+ image_embeds: torch.FloatTensor = None
+ text_embeds: Optional[torch.FloatTensor] = None
+
+ image_embeds_m: Optional[torch.FloatTensor] = None
+ text_embeds_m: Optional[torch.FloatTensor] = None
+
+ # intermediate outputs of multimodal encoder
+ encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
+ encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
+
+ itm_logits: Optional[torch.FloatTensor] = None
+ itm_labels: Optional[torch.LongTensor] = None
+
+ # intermediate outputs of multimodal decoder
+ decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
+ decoder_labels: Optional[torch.LongTensor] = None
+
+
+@dataclass
+class BlipOutput(ModelOutput):
+ # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
+ sims: Optional[BlipSimilarity] = None
+
+ intermediate_output: BlipIntermediateOutput = None
+
+ loss: Optional[torch.FloatTensor] = None
+
+ loss_itc: Optional[torch.FloatTensor] = None
+
+ loss_itm: Optional[torch.FloatTensor] = None
+
+ loss_lm: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class BlipOutputFeatures(ModelOutput):
+ """
+ Data class of features from BlipFeatureExtractor.
+
+ Args:
+ image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
+ image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
+ text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
+ text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
+
+ The first embedding or feature is for the [CLS] token.
+
+ Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ image_embeds_proj: Optional[torch.FloatTensor] = None
+
+ text_embeds: Optional[torch.FloatTensor] = None
+ text_embeds_proj: Optional[torch.FloatTensor] = None
+
+ multimodal_embeds: Optional[torch.FloatTensor] = None
diff --git a/video_llama/models/eva_vit.py b/video_llama/models/eva_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..864bffd0c2ffad18c642ce55e9d0ccf44fbe5a56
--- /dev/null
+++ b/video_llama/models/eva_vit.py
@@ -0,0 +1,442 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+
+from video_llama.common.dist_utils import download_cached_file
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ **kwargs
+ }
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+ # commit this for the orignal BERT implement
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ def forward(self):
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
+ super().__init__()
+ self.image_size = img_size
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+ self.use_checkpoint = use_checkpoint
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+ for i in range(depth)])
+# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ # trunc_normal_(self.mask_token, std=.02)
+# if isinstance(self.head, nn.Linear):
+# trunc_normal_(self.head.weight, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+# if isinstance(self.head, nn.Linear):
+# self.head.weight.data.mul_(init_scale)
+# self.head.bias.data.mul_(init_scale)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+ else:
+ x = blk(x, rel_pos_bias)
+ return x
+# x = self.norm(x)
+
+# if self.fc_norm is not None:
+# t = x[:, 1:, :]
+# return self.fc_norm(t.mean(1))
+# else:
+# return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+# x = self.head(x)
+ return x
+
+ def get_intermediate_layers(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ features = []
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ x = blk(x, rel_pos_bias)
+ features.append(x)
+
+ return features
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def convert_weights_to_fp16(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+# if isinstance(l, (nn.MultiheadAttention, Attention)):
+# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+# tensor = getattr(l, attr)
+# if tensor is not None:
+# tensor.data = tensor.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
+ model = VisionTransformer(
+ img_size=img_size,
+ patch_size=14,
+ use_mean_pooling=False,
+ embed_dim=1408,
+ depth=39,
+ num_heads=1408//88,
+ mlp_ratio=4.3637,
+ qkv_bias=True,
+ drop_path_rate=drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ use_checkpoint=use_checkpoint,
+ )
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+ cached_file = download_cached_file(
+ url, check_hash=False, progress=True
+ )
+ state_dict = torch.load(cached_file, map_location="cpu")
+ interpolate_pos_embed(model,state_dict)
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
+# print(incompatible_keys)
+
+ if precision == "fp16":
+# model.to("cuda")
+ convert_weights_to_fp16(model)
+ return model
\ No newline at end of file
diff --git a/video_llama/models/modeling_llama.py b/video_llama/models/modeling_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d980e189d902fb1a6d9ea05dc3ca91959b1c8c
--- /dev/null
+++ b/video_llama/models/modeling_llama.py
@@ -0,0 +1,755 @@
+# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+
+""" PyTorch LLaMA model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class LlamaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LlamaAttention(config=config)
+ self.mlp = LlamaMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlamaModel):
+ module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ query_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ if query_embeds is not None:
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
+ batch_size, seq_length, _ = inputs_embeds.shape
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ query_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ query_embeds=query_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+ query_embeds = None
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "query_embeds": query_embeds,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
diff --git a/video_llama/models/video_llama.py b/video_llama/models/video_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..c287887992ab48fdb7306cba2f6703e6b081712c
--- /dev/null
+++ b/video_llama/models/video_llama.py
@@ -0,0 +1,424 @@
+import logging
+import random
+
+import torch
+from torch.cuda.amp import autocast as autocast
+import torch.nn as nn
+
+from video_llama.common.registry import registry
+from video_llama.models.blip2 import Blip2Base, disabled_train
+from video_llama.models.modeling_llama import LlamaForCausalLM
+# from video_llama.models.Qformer import BertEncoder
+from transformers import LlamaTokenizer,BertConfig
+# from transformers.models.bert.modeling_bert import BertEncoder
+import einops
+import copy
+import os
+from video_llama.models.Qformer import BertConfig, BertLMHeadModel
+# from flamingo_pytorch import PerceiverResampler
+@registry.register_model("video_llama")
+class VideoLLAMA(Blip2Base):
+ """
+ BLIP2 GPT-LLAMA model.
+ """
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "pretrain_vicuna": "configs/models/video_llama.yaml",
+ }
+
+ @classmethod
+ def init_video_Qformer(cls, num_query_token, vision_width,num_hidden_layers =2):
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+ encoder_config.num_hidden_layers = num_hidden_layers
+ encoder_config.encoder_width = vision_width
+ # insert cross-attention layer every other block
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = 1
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel(config=encoder_config)
+ query_tokens = nn.Parameter(
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
+ )
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+ return Qformer, query_tokens
+
+ def __init__(
+ self,
+ vit_model="eva_clip_g",
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+ img_size=224,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision="fp16",
+ freeze_vit=True,
+ freeze_qformer=True,
+ num_query_token=32,
+ llama_model="",
+ prompt_path="",
+ prompt_template="",
+ max_txt_len=32,
+ end_sym='\n',
+ low_resource=False, # use 8 bit and put vit in cpu
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
+
+ frozen_llama_proj=True,
+ llama_proj_model='',
+ fusion_header_type= "seqTransf",
+ max_frame_pos= 32,
+ fusion_head_layers = 2,
+ num_video_query_token = 32,
+ ):
+ super().__init__()
+
+ self.tokenizer = self.init_tokenizer()
+ self.low_resource = low_resource
+
+ print('Loading VIT')
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+ )
+ if freeze_vit:
+ for name, param in self.visual_encoder.named_parameters():
+ param.requires_grad = False
+ self.visual_encoder = self.visual_encoder.eval()
+ self.visual_encoder.train = disabled_train
+ for name, param in self.ln_vision.named_parameters():
+ param.requires_grad = False
+ self.ln_vision = self.ln_vision.eval()
+ self.ln_vision.train = disabled_train
+ logging.info("freeze vision encoder")
+ print('Loading VIT Done')
+
+ print('Loading Q-Former')
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features
+ )
+ self.Qformer.cls = None
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ self.load_from_pretrained(url_or_filename=q_former_model)
+
+ if freeze_qformer:
+ for name, param in self.Qformer.named_parameters():
+ param.requires_grad = False
+ self.Qformer = self.Qformer.eval()
+ self.Qformer.train = disabled_train
+ self.query_tokens.requires_grad = False
+ logging.info("freeze Qformer")
+ logging.info('Loading Q-Former Done')
+
+ logging.info('Loading LLAMA Tokenizer')
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False, use_auth_token=os.environ["API_TOKEN"])
+ if self.llama_tokenizer.pad_token is None:
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+ DEFAULT_IMAGE_PATCH_TOKEN = ''
+ self.llama_tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.IMAGE_PATCH_TOKEN_ID = self.llama_tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
+
+ logging.info('Loading LLAMA Model')
+ if self.low_resource:
+ self.llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model,
+ torch_dtype=torch.float16,
+ load_in_8bit=True,
+ device_map={'': device_8bit},
+ use_auth_token=os.environ["API_TOKEN"]
+ )
+ else:
+ self.llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model,
+ torch_dtype=torch.float16,use_auth_token=os.environ["API_TOKEN"]
+ )
+
+ for name, param in self.llama_model.named_parameters():
+ param.requires_grad = False
+ logging.info('Loading LLAMA Done')
+
+
+ logging.info('Loading LLAMA proj')
+ self.llama_proj = nn.Linear(
+ self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
+ )
+ if llama_proj_model:
+ print("load llama proj weight: {}".format(llama_proj_model))
+ llama_proj_weight = torch.load(llama_proj_model, map_location="cpu")
+ msg = model.load_state_dict(llama_proj_weight['model'], strict=False)
+
+ if frozen_llama_proj:
+ # todo frozen llama_proj
+ for name, param in self.llama_proj.named_parameters():
+ param.requires_grad = False
+ logging.info('LLAMA proj is frozen')
+ else:
+ for name, param in self.llama_proj.named_parameters():
+ param.requires_grad = True
+ logging.info('LLAMA proj is not frozen')
+
+ logging.info('Loading llama_proj Done')
+
+ self.max_txt_len = max_txt_len
+ self.end_sym = end_sym
+
+ if prompt_path:
+ with open(prompt_path, 'r') as f:
+ raw_prompts = f.read().splitlines()
+ filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt]
+ self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
+ print('Load {} training prompts'.format(len(self.prompt_list)))
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
+ else:
+ self.prompt_list = []
+
+ self.video_frame_position_embedding = nn.Embedding(max_frame_pos, self.Qformer.config.hidden_size)
+ self.num_video_query_token = num_video_query_token
+ self.video_Qformer,self.video_query_tokens = self.init_video_Qformer(num_query_token = num_video_query_token,\
+ vision_width=self.Qformer.config.hidden_size, num_hidden_layers =2)
+
+ self.video_Qformer.cls = None
+ self.video_Qformer.bert.embeddings.word_embeddings = None
+ self.video_Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.video_Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+
+
+ def vit_to_cpu(self):
+ self.ln_vision.to("cpu")
+ self.ln_vision.float()
+ self.visual_encoder.to("cpu")
+ self.visual_encoder.float()
+
+ def encode_img(self, image):
+ device = image.device
+ # if self.low_resource:
+ # self.vit_to_cpu()
+ # image = image.to("cpu")
+
+ # input shape b,c,t,h,w
+ batch_size,_,time_length,_,_ = image.size()
+ image = einops.rearrange(image, 'b c t h w -> (b t) c h w')
+ with self.maybe_autocast():
+ # embed image features with blip2, out: (b t) q h
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ # add frame_pos embedding
+ position_ids = torch.arange(time_length, dtype=torch.long, device=query_tokens.device)
+ position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+ frame_position_embeddings = self.video_frame_position_embedding(position_ids)
+ q_hidden_state = query_output.last_hidden_state
+
+ frame_position_embeddings = frame_position_embeddings.unsqueeze(-2)
+ frame_hidden_state = einops.rearrange(q_hidden_state, '(b t) q h -> b t q h',b=batch_size,t=time_length)
+ frame_hidden_state = frame_position_embeddings + frame_hidden_state
+
+ # frame attention
+ frame_hidden_state = einops.rearrange(frame_hidden_state, 'b t q h -> b (t q) h',b=batch_size,t=time_length)
+ frame_atts = torch.ones(frame_hidden_state.size()[:-1], dtype=torch.long).to(device)
+ video_query_tokens = self.video_query_tokens.expand(frame_hidden_state.shape[0], -1, -1)
+
+ # print('attention')
+ # print(video_query_tokens.size())
+ # print(frame_hidden_state.size())
+ video_query_output = self.video_Qformer.bert(
+ query_embeds=video_query_tokens,
+ encoder_hidden_states=frame_hidden_state,
+ encoder_attention_mask=frame_atts,
+ return_dict=True,
+ )
+ video_hidden = video_query_output.last_hidden_state
+
+ inputs_llama = self.llama_proj(video_hidden)
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image_embeds.device)
+ return inputs_llama, atts_llama
+
+ def prompt_wrap(self, img_embeds, atts_img, prompt):
+ if prompt:
+ batch_size = img_embeds.shape[0]
+ # print(prompt)
+ p_before, p_after = prompt.split('')
+ p_before_tokens = self.llama_tokenizer(
+ p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_after_tokens = self.llama_tokenizer(
+ p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
+ p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
+ wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
+ wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
+
+ return wrapped_img_embeds, wrapped_atts_img
+ else:
+ return img_embeds, atts_img
+
+ def forward(self, samples):
+ if 'conv_type' in samples.keys() and samples['conv_type']=='multi':
+ num_patch_tokens = self.num_video_query_token
+ im_patch_token_id = self.IMAGE_PATCH_TOKEN_ID
+ image = samples["images"]
+ input_ids = samples['input_ids']
+ if len(image.size())==4:
+ time = 1
+ image = einops.repeat(image, 'b c h w -> b c t h w',t = time)
+ img_embeds, atts_img = self.encode_img(image)
+
+ temp_input_ids = copy.deepcopy(input_ids)
+ temp_input_ids[temp_input_ids == im_patch_token_id] = 0
+ temp_input_embedding = self.llama_model.model.embed_tokens(temp_input_ids)
+
+ new_input_embeds=[]
+ cur_image_idx = 0
+ for cur_input_ids, cur_input_embeds in zip(input_ids, temp_input_embedding):
+ cur_image_features = img_embeds[cur_image_idx]
+
+ if (cur_input_ids == im_patch_token_id).sum() != num_patch_tokens:
+ raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
+ masked_indices = torch.where(cur_input_ids == im_patch_token_id)[0]
+ mask_index_start = masked_indices[0]
+ if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patch_tokens, device=masked_indices.device, dtype=masked_indices.dtype)).any():
+ raise ValueError("The image patch tokens should be consecutive.")
+
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patch_tokens:]), dim=0)
+ new_input_embeds.append(cur_new_input_embeds)
+
+ cur_image_idx+=1
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
+ targets = samples['labels']
+ attention_mask = samples['attention_mask']
+ with self.maybe_autocast():
+ outputs = self.llama_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ )
+ loss = outputs.loss
+ return {"loss": loss}
+ else:
+ image = samples["image"]
+
+ if len(image.size()) != 5:
+ time = 1
+ image = einops.repeat(image, 'b c h w -> b c t h w',t = time)
+
+ img_embeds, atts_img = self.encode_img(image)
+
+ if self.prompt_list:
+ prompt = random.choice(self.prompt_list)
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
+
+
+ self.llama_tokenizer.padding_side = "right"
+
+ text = [t + self.end_sym for t in samples["text_input"]]
+
+ to_regress_tokens = self.llama_tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=self.max_txt_len,
+ add_special_tokens=False
+ ).to(image.device)
+
+ targets = to_regress_tokens.input_ids.masked_fill(
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
+ )
+
+ empty_targets = (
+ torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
+ dtype=torch.long).to(image.device).fill_(-100) # plus one for bos
+ )
+ targets = torch.cat([empty_targets, targets], dim=1)
+
+ batch_size = img_embeds.shape[0]
+ bos = torch.ones([batch_size, 1],
+ dtype=to_regress_tokens.input_ids.dtype,
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
+ atts_bos = atts_img[:, :1]
+
+ to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
+ inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
+ attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+
+ with self.maybe_autocast():
+ outputs = self.llama_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ )
+ loss = outputs.loss
+
+ return {"loss": loss}
+
+ @classmethod
+ def from_config(cls, cfg):
+ vit_model = cfg.get("vit_model", "eva_clip_g")
+ q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
+ img_size = cfg.get("image_size")
+ num_query_token = cfg.get("num_query_token")
+ llama_model = cfg.get("llama_model")
+
+ drop_path_rate = cfg.get("drop_path_rate", 0)
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+ vit_precision = cfg.get("vit_precision", "fp16")
+ freeze_vit = cfg.get("freeze_vit", True)
+ freeze_qformer = cfg.get("freeze_qformer", True)
+ low_resource = cfg.get("low_resource", False)
+ device_8bit = cfg.get("device_8bit", 0)
+
+ prompt_path = cfg.get("prompt_path", "")
+ prompt_template = cfg.get("prompt_template", "")
+ max_txt_len = cfg.get("max_txt_len", 32)
+ end_sym = cfg.get("end_sym", '\n')
+
+ frozen_llama_proj = cfg.get("frozen_llama_proj", True)
+ llama_proj_model = cfg.get("llama_proj_model", '')
+
+ fusion_header_type = cfg.get("fusion_header_type", 'seqTransf')
+ max_frame_pos = cfg.get("max_frame_pos", 32)
+ fusion_head_layers = cfg.get("fusion_head_layers", 2)
+ num_video_query_token = cfg.get("num_video_query_token", 32)
+
+ model = cls(
+ vit_model=vit_model,
+ q_former_model=q_former_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ freeze_qformer=freeze_qformer,
+ num_query_token=num_query_token,
+ llama_model=llama_model,
+ prompt_path=prompt_path,
+ prompt_template=prompt_template,
+ max_txt_len=max_txt_len,
+ end_sym=end_sym,
+ low_resource=low_resource,
+ device_8bit=device_8bit,
+ fusion_header_type=fusion_header_type,
+ max_frame_pos=max_frame_pos,
+ fusion_head_layers=fusion_head_layers,
+ frozen_llama_proj=frozen_llama_proj,
+ num_video_query_token=num_video_query_token
+ )
+
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
+ if ckpt_path:
+ print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ msg = model.load_state_dict(ckpt['model'], strict=False)
+ return model
diff --git a/video_llama/processors/__init__.py b/video_llama/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..169237f3dd45dba53cf77f40c8a69e835d0bcecc
--- /dev/null
+++ b/video_llama/processors/__init__.py
@@ -0,0 +1,38 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from video_llama.processors.base_processor import BaseProcessor
+from video_llama.processors.blip_processors import (
+ Blip2ImageTrainProcessor,
+ Blip2ImageEvalProcessor,
+ BlipCaptionProcessor,
+)
+from video_llama.processors.video_processor import (
+ AlproVideoTrainProcessor,
+ AlproVideoEvalProcessor
+)
+from video_llama.common.registry import registry
+
+__all__ = [
+ "BaseProcessor",
+ "Blip2ImageTrainProcessor",
+ "Blip2ImageEvalProcessor",
+ "BlipCaptionProcessor",
+ "AlproVideoTrainProcessor",
+ "AlproVideoEvalProcessor",
+]
+
+
+def load_processor(name, cfg=None):
+ """
+ Example
+
+ >>> processor = load_processor("alpro_video_train", cfg=None)
+ """
+ processor = registry.get_processor_class(name).from_config(cfg)
+
+ return processor
diff --git a/video_llama/processors/base_processor.py b/video_llama/processors/base_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..39b33cdf8fcd97cfd3e4a5fbece6593357af9d41
--- /dev/null
+++ b/video_llama/processors/base_processor.py
@@ -0,0 +1,26 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from omegaconf import OmegaConf
+
+
+class BaseProcessor:
+ def __init__(self):
+ self.transform = lambda x: x
+ return
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ return cls()
+
+ def build(self, **kwargs):
+ cfg = OmegaConf.create(kwargs)
+
+ return self.from_config(cfg)
diff --git a/video_llama/processors/blip_processors.py b/video_llama/processors/blip_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e603b638607921440bf7c4fcf22c5f1aeb7f20d
--- /dev/null
+++ b/video_llama/processors/blip_processors.py
@@ -0,0 +1,142 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import re
+
+from video_llama.common.registry import registry
+from video_llama.processors.base_processor import BaseProcessor
+from video_llama.processors.randaugment import RandomAugment
+from omegaconf import OmegaConf
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class BlipImageBaseProcessor(BaseProcessor):
+ def __init__(self, mean=None, std=None):
+ if mean is None:
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ if std is None:
+ std = (0.26862954, 0.26130258, 0.27577711)
+
+ self.normalize = transforms.Normalize(mean, std)
+
+
+@registry.register_processor("blip_caption")
+class BlipCaptionProcessor(BaseProcessor):
+ def __init__(self, prompt="", max_words=50):
+ self.prompt = prompt
+ self.max_words = max_words
+
+ def __call__(self, caption):
+ caption = self.prompt + self.pre_caption(caption)
+
+ return caption
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ prompt = cfg.get("prompt", "")
+ max_words = cfg.get("max_words", 50)
+
+ return cls(prompt=prompt, max_words=max_words)
+
+ def pre_caption(self, caption):
+ caption = re.sub(
+ r"([.!\"()*#:;~])",
+ " ",
+ caption.lower(),
+ )
+ caption = re.sub(
+ r"\s{2,}",
+ " ",
+ caption,
+ )
+ caption = caption.rstrip("\n")
+ caption = caption.strip(" ")
+
+ # truncate caption
+ caption_words = caption.split(" ")
+ if len(caption_words) > self.max_words:
+ caption = " ".join(caption_words[: self.max_words])
+
+ return caption
+
+
+@registry.register_processor("blip2_image_train")
+class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ image_size,
+ scale=(min_scale, max_scale),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 224)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ min_scale = cfg.get("min_scale", 0.5)
+ max_scale = cfg.get("max_scale", 1.0)
+
+ return cls(
+ image_size=image_size,
+ mean=mean,
+ std=std,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ )
+
+
+@registry.register_processor("blip2_image_eval")
+class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
+ def __init__(self, image_size=224, mean=None, std=None):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+ ),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 224)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ return cls(image_size=image_size, mean=mean, std=std)
+
diff --git a/video_llama/processors/functional_video.py b/video_llama/processors/functional_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..597a29315d4e1a575e7209edb0618eeaf4fc024a
--- /dev/null
+++ b/video_llama/processors/functional_video.py
@@ -0,0 +1,121 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import warnings
+
+import torch
+
+
+def _is_tensor_video_clip(clip):
+ if not torch.is_tensor(clip):
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+ if not clip.ndimension() == 4:
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+ return True
+
+
+def crop(clip, i, j, h, w):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ """
+ if len(clip.size()) != 4:
+ raise ValueError("clip should be a 4D tensor")
+ return clip[..., i : i + h, j : j + w]
+
+
+def resize(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(
+ f"target size should be tuple (height, width), instead got {target_size}"
+ )
+ return torch.nn.functional.interpolate(
+ clip, size=target_size, mode=interpolation_mode, align_corners=False
+ )
+
+
+def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
+ """
+ Do spatial cropping and resizing to the video clip
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
+ h (int): Height of the cropped region.
+ w (int): Width of the cropped region.
+ size (tuple(int, int)): height and width of resized clip
+ Returns:
+ clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ clip = crop(clip, i, j, h, w)
+ clip = resize(clip, size, interpolation_mode)
+ return clip
+
+
+def center_crop(clip, crop_size):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ th, tw = crop_size
+ if h < th or w < tw:
+ raise ValueError("height and width must be no smaller than crop_size")
+
+ i = int(round((h - th) / 2.0))
+ j = int(round((w - tw) / 2.0))
+ return crop(clip, i, j, th, tw)
+
+
+def to_tensor(clip):
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
+ """
+ _is_tensor_video_clip(clip)
+ if not clip.dtype == torch.uint8:
+ raise TypeError(
+ "clip tensor should have data type uint8. Got %s" % str(clip.dtype)
+ )
+ return clip.float().permute(3, 0, 1, 2) / 255.0
+
+
+def normalize(clip, mean, std, inplace=False):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
+ mean (tuple): pixel RGB mean. Size is (3)
+ std (tuple): pixel standard deviation. Size is (3)
+ Returns:
+ normalized clip (torch.tensor): Size is (C, T, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ if not inplace:
+ clip = clip.clone()
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+ return clip
+
+
+def hflip(clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
+ Returns:
+ flipped clip (torch.tensor): Size is (C, T, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ return clip.flip(-1)
diff --git a/video_llama/processors/randaugment.py b/video_llama/processors/randaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..7034a49ad5fc63b97910790017432617ff4c6d7b
--- /dev/null
+++ b/video_llama/processors/randaugment.py
@@ -0,0 +1,398 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import cv2
+import numpy as np
+
+import torch
+
+
+## aug functions
+def identity_func(img):
+ return img
+
+
+def autocontrast_func(img, cutoff=0):
+ """
+ same output as PIL.ImageOps.autocontrast
+ """
+ n_bins = 256
+
+ def tune_channel(ch):
+ n = ch.size
+ cut = cutoff * n // 100
+ if cut == 0:
+ high, low = ch.max(), ch.min()
+ else:
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ low = np.argwhere(np.cumsum(hist) > cut)
+ low = 0 if low.shape[0] == 0 else low[0]
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+ if high <= low:
+ table = np.arange(n_bins)
+ else:
+ scale = (n_bins - 1) / (high - low)
+ offset = -low * scale
+ table = np.arange(n_bins) * scale + offset
+ table[table < 0] = 0
+ table[table > n_bins - 1] = n_bins - 1
+ table = table.clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def equalize_func(img):
+ """
+ same output as PIL.ImageOps.equalize
+ PIL's implementation is different from cv2.equalize
+ """
+ n_bins = 256
+
+ def tune_channel(ch):
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ non_zero_hist = hist[hist != 0].reshape(-1)
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+ if step == 0:
+ return ch
+ n = np.empty_like(hist)
+ n[0] = step // 2
+ n[1:] = hist[:-1]
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+ """
+ like PIL, rotate by degree, not radians
+ """
+ H, W = img.shape[0], img.shape[1]
+ center = W / 2, H / 2
+ M = cv2.getRotationMatrix2D(center, degree, 1)
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+ return out
+
+
+def solarize_func(img, thresh=128):
+ """
+ same output as PIL.ImageOps.posterize
+ """
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
+ table = table.clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def color_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Color
+ """
+ ## implementation according to PIL definition, quite slow
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+ # out = blend(degenerate, img, factor)
+ # M = (
+ # np.eye(3) * factor
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+ # )[np.newaxis, np.newaxis, :]
+ M = np.float32(
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+ return out
+
+
+def contrast_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Contrast
+ """
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+ table = (
+ np.array([(el - mean) * factor + mean for el in range(256)])
+ .clip(0, 255)
+ .astype(np.uint8)
+ )
+ out = table[img]
+ return out
+
+
+def brightness_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Contrast
+ """
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def sharpness_func(img, factor):
+ """
+ The differences the this result and PIL are all on the 4 boundaries, the center
+ areas are same
+ """
+ kernel = np.ones((3, 3), dtype=np.float32)
+ kernel[1][1] = 5
+ kernel /= 13
+ degenerate = cv2.filter2D(img, -1, kernel)
+ if factor == 0.0:
+ out = degenerate
+ elif factor == 1.0:
+ out = img
+ else:
+ out = img.astype(np.float32)
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+ out = out.astype(np.uint8)
+ return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+ """
+ same output as PIL.Image.transform
+ """
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+ """
+ same output as PIL.Image.transform
+ """
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def posterize_func(img, bits):
+ """
+ same output as PIL.ImageOps.posterize
+ """
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+ return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+ replace = np.array(replace, dtype=np.uint8)
+ H, W = img.shape[0], img.shape[1]
+ rh, rw = np.random.random(2)
+ pad_size = pad_size // 2
+ ch, cw = int(rh * H), int(rw * W)
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+ out = img.copy()
+ out[x1:x2, y1:y2, :] = replace
+ return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+
+ return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 0.3
+ if np.random.random() > 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * float(translate_const)
+ if np.random.random() > 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * cutout_const)
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 256)
+ return (level,)
+
+ return level_to_args
+
+
+def none_level_to_args(level):
+ return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 4)
+ return (level,)
+
+ return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 30
+ if np.random.random() < 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+func_dict = {
+ "Identity": identity_func,
+ "AutoContrast": autocontrast_func,
+ "Equalize": equalize_func,
+ "Rotate": rotate_func,
+ "Solarize": solarize_func,
+ "Color": color_func,
+ "Contrast": contrast_func,
+ "Brightness": brightness_func,
+ "Sharpness": sharpness_func,
+ "ShearX": shear_x_func,
+ "TranslateX": translate_x_func,
+ "TranslateY": translate_y_func,
+ "Posterize": posterize_func,
+ "ShearY": shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+ "Identity": none_level_to_args,
+ "AutoContrast": none_level_to_args,
+ "Equalize": none_level_to_args,
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
+ "Color": enhance_level_to_args(MAX_LEVEL),
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
+}
+
+
+class RandomAugment(object):
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+ self.N = N
+ self.M = M
+ self.isPIL = isPIL
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N)
+ return [(op, 0.5, self.M) for op in sampled_ops]
+
+ def __call__(self, img):
+ if self.isPIL:
+ img = np.array(img)
+ ops = self.get_random_ops()
+ for name, prob, level in ops:
+ if np.random.random() > prob:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return img
+
+
+class VideoRandomAugment(object):
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
+ self.N = N
+ self.M = M
+ self.p = p
+ self.tensor_in_tensor_out = tensor_in_tensor_out
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
+ return [(op, self.M) for op in sampled_ops]
+
+ def __call__(self, frames):
+ assert (
+ frames.shape[-1] == 3
+ ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
+
+ if self.tensor_in_tensor_out:
+ frames = frames.numpy().astype(np.uint8)
+
+ num_frames = frames.shape[0]
+
+ ops = num_frames * [self.get_random_ops()]
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
+
+ frames = torch.stack(
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
+ ).float()
+
+ return frames
+
+ def _aug(self, img, ops, apply_or_not):
+ for i, (name, level) in enumerate(ops):
+ if not apply_or_not[i]:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return torch.from_numpy(img)
+
+
+if __name__ == "__main__":
+ a = RandomAugment()
+ img = np.random.randn(32, 32, 3)
+ a(img)
diff --git a/video_llama/processors/transforms_video.py b/video_llama/processors/transforms_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..1106f388c4091f919e0e9602fcb1363e9caec9a6
--- /dev/null
+++ b/video_llama/processors/transforms_video.py
@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+
+import numbers
+import random
+
+from torchvision.transforms import (
+ RandomCrop,
+ RandomResizedCrop,
+)
+
+import video_llama.processors.functional_video as F
+
+
+__all__ = [
+ "RandomCropVideo",
+ "RandomResizedCropVideo",
+ "CenterCropVideo",
+ "NormalizeVideo",
+ "ToTensorVideo",
+ "RandomHorizontalFlipVideo",
+]
+
+
+class RandomCropVideo(RandomCrop):
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ Returns:
+ torch.tensor: randomly cropped/resized video clip.
+ size is (C, T, OH, OW)
+ """
+ i, j, h, w = self.get_params(clip, self.size)
+ return F.crop(clip, i, j, h, w)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size})"
+
+
+class RandomResizedCropVideo(RandomResizedCrop):
+ def __init__(
+ self,
+ size,
+ scale=(0.08, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(
+ f"size should be tuple (height, width), instead got {size}"
+ )
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+ self.scale = scale
+ self.ratio = ratio
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ Returns:
+ torch.tensor: randomly cropped/resized video clip.
+ size is (C, T, H, W)
+ """
+ i, j, h, w = self.get_params(clip, self.scale, self.ratio)
+ return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
+
+
+class CenterCropVideo:
+ def __init__(self, crop_size):
+ if isinstance(crop_size, numbers.Number):
+ self.crop_size = (int(crop_size), int(crop_size))
+ else:
+ self.crop_size = crop_size
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ Returns:
+ torch.tensor: central cropping of video clip. Size is
+ (C, T, crop_size, crop_size)
+ """
+ return F.center_crop(clip, self.crop_size)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(crop_size={self.crop_size})"
+
+
+class NormalizeVideo:
+ """
+ Normalize the video clip by mean subtraction and division by standard deviation
+ Args:
+ mean (3-tuple): pixel RGB mean
+ std (3-tuple): pixel RGB standard deviation
+ inplace (boolean): whether do in-place normalization
+ """
+
+ def __init__(self, mean, std, inplace=False):
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
+ """
+ return F.normalize(clip, self.mean, self.std, self.inplace)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
+
+
+class ToTensorVideo:
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
+ """
+ return F.to_tensor(clip)
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__
+
+
+class RandomHorizontalFlipVideo:
+ """
+ Flip the video clip along the horizonal direction with a given probability
+ Args:
+ p (float): probability of the clip being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Size is (C, T, H, W)
+ Return:
+ clip (torch.tensor): Size is (C, T, H, W)
+ """
+ if random.random() < self.p:
+ clip = F.hflip(clip)
+ return clip
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(p={self.p})"
diff --git a/video_llama/processors/video_processor.py b/video_llama/processors/video_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b272e318230c544748818476bbe4caa1fc9f847d
--- /dev/null
+++ b/video_llama/processors/video_processor.py
@@ -0,0 +1,237 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import torch
+from video_llama.common.registry import registry
+from decord import VideoReader
+import decord
+import numpy as np
+from video_llama.processors import transforms_video
+from video_llama.processors.base_processor import BaseProcessor
+from video_llama.processors.randaugment import VideoRandomAugment
+from video_llama.processors import functional_video as F
+from omegaconf import OmegaConf
+from torchvision import transforms
+import random as rnd
+
+
+MAX_INT = registry.get("MAX_INT")
+decord.bridge.set_bridge("torch")
+
+def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform", return_msg = False):
+ decord.bridge.set_bridge("torch")
+ vr = VideoReader(uri=video_path, height=height, width=width)
+
+ vlen = len(vr)
+ start, end = 0, vlen
+
+ n_frms = min(n_frms, vlen)
+
+ if sampling == "uniform":
+ indices = np.arange(start, end, vlen / n_frms).astype(int).tolist()
+ elif sampling == "headtail":
+ indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))
+ indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2))
+ indices = indices_h + indices_t
+ else:
+ raise NotImplementedError
+
+ # get_batch -> T, H, W, C
+ temp_frms = vr.get_batch(indices)
+ # print(type(temp_frms))
+ tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+ frms = tensor_frms.permute(3, 0, 1, 2).float() # (C, T, H, W)
+
+ if not return_msg:
+ return frms
+
+ fps = float(vr.get_avg_fps())
+ sec = ", ".join([str(round(f / fps, 1)) for f in indices])
+ # " " should be added in the start and end
+ msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. "
+ return frms, msg
+
+
+class AlproVideoBaseProcessor(BaseProcessor):
+ def __init__(self, mean=None, std=None, n_frms=MAX_INT):
+ if mean is None:
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ if std is None:
+ std = (0.26862954, 0.26130258, 0.27577711)
+
+ self.normalize = transforms_video.NormalizeVideo(mean, std)
+
+ self.n_frms = n_frms
+
+
+class ToUint8(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, tensor):
+ return tensor.to(torch.uint8)
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+
+class ToTHWC(object):
+ """
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C)
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, tensor):
+ return tensor.permute(1, 2, 3, 0)
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+
+class ResizeVideo(object):
+ def __init__(self, target_size, interpolation_mode="bilinear"):
+ self.target_size = target_size
+ self.interpolation_mode = interpolation_mode
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ Returns:
+ torch.tensor: central cropping of video clip. Size is
+ (C, T, crop_size, crop_size)
+ """
+ return F.resize(clip, self.target_size, self.interpolation_mode)
+
+ def __repr__(self):
+ return self.__class__.__name__ + "(resize_size={0})".format(self.target_size)
+
+
+@registry.register_processor("alpro_video_train")
+class AlproVideoTrainProcessor(AlproVideoBaseProcessor):
+ def __init__(
+ self,
+ image_size=384,
+ mean=None,
+ std=None,
+ min_scale=0.5,
+ max_scale=1.0,
+ n_frms=MAX_INT,
+ ):
+ super().__init__(mean=mean, std=std, n_frms=n_frms)
+
+ self.image_size = image_size
+
+ self.transform = transforms.Compose(
+ [
+ # Video size is (C, T, H, W)
+ transforms_video.RandomResizedCropVideo(
+ image_size,
+ scale=(min_scale, max_scale),
+ interpolation_mode="bicubic",
+ ),
+ ToTHWC(), # C, T, H, W -> T, H, W, C
+ ToUint8(),
+ transforms_video.ToTensorVideo(), # T, H, W, C -> C, T, H, W
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, vpath):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ Returns:
+ torch.tensor: video clip after transforms. Size is (C, T, size, size).
+ """
+ clip = load_video(
+ video_path=vpath,
+ n_frms=self.n_frms,
+ height=self.image_size,
+ width=self.image_size,
+ sampling="headtail",
+ )
+
+ return self.transform(clip)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 256)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ min_scale = cfg.get("min_scale", 0.5)
+ max_scale = cfg.get("max_scale", 1.0)
+
+ n_frms = cfg.get("n_frms", MAX_INT)
+
+ return cls(
+ image_size=image_size,
+ mean=mean,
+ std=std,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ n_frms=n_frms,
+ )
+
+
+@registry.register_processor("alpro_video_eval")
+class AlproVideoEvalProcessor(AlproVideoBaseProcessor):
+ def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT):
+ super().__init__(mean=mean, std=std, n_frms=n_frms)
+
+ self.image_size = image_size
+
+ # Input video size is (C, T, H, W)
+ self.transform = transforms.Compose(
+ [
+ # frames will be resized during decord loading.
+ ToUint8(), # C, T, H, W
+ ToTHWC(), # T, H, W, C
+ transforms_video.ToTensorVideo(), # C, T, H, W
+ self.normalize, # C, T, H, W
+ ]
+ )
+
+ def __call__(self, vpath):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ Returns:
+ torch.tensor: video clip after transforms. Size is (C, T, size, size).
+ """
+ clip = load_video(
+ video_path=vpath,
+ n_frms=self.n_frms,
+ height=self.image_size,
+ width=self.image_size,
+ )
+
+ return self.transform(clip)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 256)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ n_frms = cfg.get("n_frms", MAX_INT)
+
+ return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)
diff --git a/video_llama/runners/__init__.py b/video_llama/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ffe5b0b10e013fb6d69eb6879b1e42c06d5b447
--- /dev/null
+++ b/video_llama/runners/__init__.py
@@ -0,0 +1,10 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from video_llama.runners.runner_base import RunnerBase
+
+__all__ = ["RunnerBase"]
diff --git a/video_llama/runners/runner_base.py b/video_llama/runners/runner_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c944123917dd0bf9947f4204f9044538a0f8bf22
--- /dev/null
+++ b/video_llama/runners/runner_base.py
@@ -0,0 +1,658 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import json
+import logging
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from video_llama.common.dist_utils import (
+ download_cached_file,
+ get_rank,
+ get_world_size,
+ is_main_process,
+ main_process,
+)
+from video_llama.common.registry import registry
+from video_llama.common.utils import is_url
+from video_llama.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
+from video_llama.datasets.datasets.dataloader_utils import (
+ IterLoader,
+ MultiIterLoader,
+ PrefetchLoader,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, DistributedSampler
+
+
+@registry.register_runner("runner_base")
+class RunnerBase:
+ """
+ A runner class to train and evaluate a model given a task and datasets.
+
+ The runner uses pytorch distributed data parallel by default. Future release
+ will support other distributed frameworks.
+ """
+
+ def __init__(self, cfg, task, model, datasets, job_id):
+ self.config = cfg
+ self.job_id = job_id
+
+ self.task = task
+ self.datasets = datasets
+
+ self._model = model
+
+ self._wrapped_model = None
+ self._device = None
+ self._optimizer = None
+ self._scaler = None
+ self._dataloaders = None
+ self._lr_sched = None
+
+ self.start_epoch = 0
+
+ # self.setup_seeds()
+ self.setup_output_dir()
+
+ @property
+ def device(self):
+ if self._device is None:
+ self._device = torch.device(self.config.run_cfg.device)
+
+ return self._device
+
+ @property
+ def use_distributed(self):
+ return self.config.run_cfg.distributed
+
+ @property
+ def model(self):
+ """
+ A property to get the DDP-wrapped model on the device.
+ """
+ # move model to device
+ if self._model.device != self.device:
+ self._model = self._model.to(self.device)
+
+ # distributed training wrapper
+ if self.use_distributed:
+ if self._wrapped_model is None:
+ self._wrapped_model = DDP(
+ self._model, device_ids=[self.config.run_cfg.gpu]
+ )
+ else:
+ self._wrapped_model = self._model
+
+ return self._wrapped_model
+
+ @property
+ def optimizer(self):
+ # TODO make optimizer class and configurations
+ if self._optimizer is None:
+ num_parameters = 0
+ p_wd, p_non_wd = [], []
+ for n, p in self.model.named_parameters():
+ if not p.requires_grad:
+ continue # frozen weights
+ print(n)
+ if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
+ p_non_wd.append(p)
+ else:
+ p_wd.append(p)
+ num_parameters += p.data.nelement()
+ logging.info("number of trainable parameters: %d" % num_parameters)
+ optim_params = [
+ {
+ "params": p_wd,
+ "weight_decay": float(self.config.run_cfg.weight_decay),
+ },
+ {"params": p_non_wd, "weight_decay": 0},
+ ]
+ beta2 = self.config.run_cfg.get("beta2", 0.999)
+ self._optimizer = torch.optim.AdamW(
+ optim_params,
+ lr=float(self.config.run_cfg.init_lr),
+ weight_decay=float(self.config.run_cfg.weight_decay),
+ betas=(0.9, beta2),
+ )
+
+ return self._optimizer
+
+ @property
+ def scaler(self):
+ amp = self.config.run_cfg.get("amp", False)
+
+ if amp:
+ if self._scaler is None:
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ return self._scaler
+
+ @property
+ def lr_scheduler(self):
+ """
+ A property to get and create learning rate scheduler by split just in need.
+ """
+ if self._lr_sched is None:
+ lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
+
+ # max_epoch = self.config.run_cfg.max_epoch
+ max_epoch = self.max_epoch
+ # min_lr = self.config.run_cfg.min_lr
+ min_lr = self.min_lr
+ # init_lr = self.config.run_cfg.init_lr
+ init_lr = self.init_lr
+
+ # optional parameters
+ decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
+ warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
+ warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
+ iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None)
+
+ if iters_per_epoch is None:
+ try:
+ iters_per_epoch = len(self.dataloaders['train'])
+ except (AttributeError, TypeError):
+ iters_per_epoch = 10000
+
+ self._lr_sched = lr_sched_cls(
+ optimizer=self.optimizer,
+ max_epoch=max_epoch,
+ iters_per_epoch=iters_per_epoch,
+ min_lr=min_lr,
+ init_lr=init_lr,
+ decay_rate=decay_rate,
+ warmup_start_lr=warmup_start_lr,
+ warmup_steps=warmup_steps,
+ )
+
+ return self._lr_sched
+
+ @property
+ def dataloaders(self) -> dict:
+ """
+ A property to get and create dataloaders by split just in need.
+
+ If no train_dataset_ratio is provided, concatenate map-style datasets and
+ chain wds.DataPipe datasets separately. Training set becomes a tuple
+ (ConcatDataset, ChainDataset), both are optional but at least one of them is
+ required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
+
+ If train_dataset_ratio is provided, create a MultiIterLoader to sample
+ each dataset by ratios during training.
+
+ Currently do not support multiple datasets for validation and test.
+
+ Returns:
+ dict: {split_name: (tuples of) dataloader}
+ """
+ if self._dataloaders is None:
+
+ # concatenate map-style datasets and chain wds.DataPipe datasets separately
+ # training set becomes a tuple (ConcatDataset, ChainDataset), both are
+ # optional but at least one of them is required. The resultant ConcatDataset
+ # and ChainDataset will be sampled evenly.
+ logging.info(
+ "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
+ )
+
+ datasets = reorg_datasets_by_split(self.datasets)
+ self.datasets = datasets
+ # self.datasets = concat_datasets(datasets)
+
+ # print dataset statistics after concatenation/chaining
+ for split_name in self.datasets:
+ if isinstance(self.datasets[split_name], tuple) or isinstance(
+ self.datasets[split_name], list
+ ):
+ # mixed wds.DataPipeline and torch.utils.data.Dataset
+ num_records = sum(
+ [
+ len(d)
+ if not type(d) in [wds.DataPipeline, ChainDataset]
+ else 0
+ for d in self.datasets[split_name]
+ ]
+ )
+
+ else:
+ if hasattr(self.datasets[split_name], "__len__"):
+ # a single map-style dataset
+ num_records = len(self.datasets[split_name])
+ else:
+ # a single wds.DataPipeline
+ num_records = -1
+ logging.info(
+ "Only a single wds.DataPipeline dataset, no __len__ attribute."
+ )
+
+ if num_records >= 0:
+ logging.info(
+ "Loaded {} records for {} split from the dataset.".format(
+ num_records, split_name
+ )
+ )
+
+ # create dataloaders
+ split_names = sorted(self.datasets.keys())
+
+ datasets = [self.datasets[split] for split in split_names]
+ is_trains = [split in self.train_splits for split in split_names]
+
+ batch_sizes = [
+ self.config.run_cfg.batch_size_train
+ if split == "train"
+ else self.config.run_cfg.batch_size_eval
+ for split in split_names
+ ]
+
+ collate_fns = []
+ for dataset in datasets:
+ if isinstance(dataset, tuple) or isinstance(dataset, list):
+ collate_fns.append([getattr(d, "collater", None) for d in dataset])
+ else:
+ collate_fns.append(getattr(dataset, "collater", None))
+
+ dataloaders = self.create_loaders(
+ datasets=datasets,
+ num_workers=self.config.run_cfg.num_workers,
+ batch_sizes=batch_sizes,
+ is_trains=is_trains,
+ collate_fns=collate_fns,
+ )
+
+ self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+ return self._dataloaders
+
+ @property
+ def cuda_enabled(self):
+ return self.device.type == "cuda"
+
+ @property
+ def max_epoch(self):
+ return int(self.config.run_cfg.max_epoch)
+
+ @property
+ def log_freq(self):
+ log_freq = self.config.run_cfg.get("log_freq", 50)
+ return int(log_freq)
+
+ @property
+ def init_lr(self):
+ return float(self.config.run_cfg.init_lr)
+
+ @property
+ def min_lr(self):
+ return float(self.config.run_cfg.min_lr)
+
+ @property
+ def accum_grad_iters(self):
+ return int(self.config.run_cfg.get("accum_grad_iters", 1))
+
+ @property
+ def valid_splits(self):
+ valid_splits = self.config.run_cfg.get("valid_splits", [])
+
+ if len(valid_splits) == 0:
+ logging.info("No validation splits found.")
+
+ return valid_splits
+
+ @property
+ def test_splits(self):
+ test_splits = self.config.run_cfg.get("test_splits", [])
+
+ return test_splits
+
+ @property
+ def train_splits(self):
+ train_splits = self.config.run_cfg.get("train_splits", [])
+
+ if len(train_splits) == 0:
+ logging.info("Empty train splits.")
+
+ return train_splits
+
+ @property
+ def evaluate_only(self):
+ """
+ Set to True to skip training.
+ """
+ return self.config.run_cfg.evaluate
+
+ @property
+ def use_dist_eval_sampler(self):
+ return self.config.run_cfg.get("use_dist_eval_sampler", True)
+
+ @property
+ def resume_ckpt_path(self):
+ return self.config.run_cfg.get("resume_ckpt_path", None)
+
+ @property
+ def train_loader(self):
+ train_dataloader = self.dataloaders["train"]
+
+ return train_dataloader
+
+ def setup_output_dir(self):
+ lib_root = Path(registry.get_path("library_root"))
+
+ output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
+ result_dir = output_dir / "result"
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ registry.register_path("result_dir", str(result_dir))
+ registry.register_path("output_dir", str(output_dir))
+
+ self.result_dir = result_dir
+ self.output_dir = output_dir
+
+ def train(self):
+ start_time = time.time()
+ best_agg_metric = 0
+ best_epoch = 0
+
+ self.log_config()
+
+ # resume from checkpoint if specified
+ if not self.evaluate_only and self.resume_ckpt_path is not None:
+ self._load_checkpoint(self.resume_ckpt_path)
+
+ for cur_epoch in range(self.start_epoch, self.max_epoch):
+ # training phase
+ if not self.evaluate_only:
+ logging.info("Start training")
+ train_stats = self.train_epoch(cur_epoch)
+ self.log_stats(split_name="train", stats=train_stats)
+
+ # evaluation phase
+ if len(self.valid_splits) > 0:
+ for split_name in self.valid_splits:
+ logging.info("Evaluating on {}.".format(split_name))
+
+ val_log = self.eval_epoch(
+ split_name=split_name, cur_epoch=cur_epoch
+ )
+ if val_log is not None:
+ if is_main_process():
+ assert (
+ "agg_metrics" in val_log
+ ), "No agg_metrics found in validation log."
+
+ agg_metrics = val_log["agg_metrics"]
+ if agg_metrics > best_agg_metric and split_name == "val":
+ best_epoch, best_agg_metric = cur_epoch, agg_metrics
+
+ self._save_checkpoint(cur_epoch, is_best=True)
+
+ val_log.update({"best_epoch": best_epoch})
+ self.log_stats(val_log, split_name)
+
+ else:
+ # if no validation split is provided, we just save the checkpoint at the end of each epoch.
+ if not self.evaluate_only:
+ self._save_checkpoint(cur_epoch, is_best=False)
+
+ if self.evaluate_only:
+ break
+
+ if self.config.run_cfg.distributed:
+ dist.barrier()
+
+ # testing phase
+ test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
+ self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logging.info("Training time {}".format(total_time_str))
+
+ def evaluate(self, cur_epoch="best", skip_reload=False):
+ test_logs = dict()
+
+ if len(self.test_splits) > 0:
+ for split_name in self.test_splits:
+ test_logs[split_name] = self.eval_epoch(
+ split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
+ )
+
+ return test_logs
+
+ def train_epoch(self, epoch):
+ # train
+ self.model.train()
+
+ return self.task.train_epoch(
+ epoch=epoch,
+ model=self.model,
+ data_loader=self.train_loader,
+ optimizer=self.optimizer,
+ scaler=self.scaler,
+ lr_scheduler=self.lr_scheduler,
+ cuda_enabled=self.cuda_enabled,
+ log_freq=self.log_freq,
+ accum_grad_iters=self.accum_grad_iters,
+ )
+
+ @torch.no_grad()
+ def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
+ """
+ Evaluate the model on a given split.
+
+ Args:
+ split_name (str): name of the split to evaluate on.
+ cur_epoch (int): current epoch.
+ skip_reload_best (bool): whether to skip reloading the best checkpoint.
+ During training, we will reload the best checkpoint for validation.
+ During testing, we will use provided weights and skip reloading the best checkpoint .
+ """
+ data_loader = self.dataloaders.get(split_name, None)
+ assert data_loader, "data_loader for split {} is None.".format(split_name)
+
+ # TODO In validation, you need to compute loss as well as metrics
+ # TODO consider moving to model.before_evaluation()
+ model = self.unwrap_dist_model(self.model)
+ if not skip_reload and cur_epoch == "best":
+ model = self._reload_best_model(model)
+ model.eval()
+
+ self.task.before_evaluation(
+ model=model,
+ dataset=self.datasets[split_name],
+ )
+ results = self.task.evaluation(model, data_loader)
+
+ if results is not None:
+ return self.task.after_evaluation(
+ val_result=results,
+ split_name=split_name,
+ epoch=cur_epoch,
+ )
+
+ def unwrap_dist_model(self, model):
+ if self.use_distributed:
+ return model.module
+ else:
+ return model
+
+ def create_loaders(
+ self,
+ datasets,
+ num_workers,
+ batch_sizes,
+ is_trains,
+ collate_fns,
+ dataset_ratios=None,
+ ):
+ """
+ Create dataloaders for training and validation.
+ """
+
+ def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
+ # create a single dataloader for each split
+ if isinstance(dataset, ChainDataset) or isinstance(
+ dataset, wds.DataPipeline
+ ):
+ # wds.WebdDataset instance are chained together
+ # webdataset.DataPipeline has its own sampler and collate_fn
+ loader = iter(
+ DataLoader(
+ dataset,
+ batch_size=bsz,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
+ )
+ else:
+ # map-style dataset are concatenated together
+ # setup distributed sampler
+ if self.use_distributed:
+ sampler = DistributedSampler(
+ dataset,
+ shuffle=is_train,
+ num_replicas=get_world_size(),
+ rank=get_rank(),
+ )
+ if not self.use_dist_eval_sampler:
+ # e.g. retrieval evaluation
+ sampler = sampler if is_train else None
+ else:
+ sampler = None
+
+ loader = DataLoader(
+ dataset,
+ batch_size=bsz,
+ num_workers=num_workers,
+ pin_memory=True,
+ sampler=sampler,
+ shuffle=sampler is None and is_train,
+ collate_fn=collate_fn,
+ drop_last=True if is_train else False,
+ )
+ loader = PrefetchLoader(loader)
+
+ if is_train:
+ loader = IterLoader(loader, use_distributed=self.use_distributed)
+
+ return loader
+
+ loaders = []
+
+ for dataset, bsz, is_train, collate_fn in zip(
+ datasets, batch_sizes, is_trains, collate_fns
+ ):
+ if isinstance(dataset, list) or isinstance(dataset, tuple):
+ if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None:
+ dataset_ratios = [d.sample_ratio for d in dataset]
+ loader = MultiIterLoader(
+ loaders=[
+ _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
+ for i, d in enumerate(dataset)
+ ],
+ ratios=dataset_ratios,
+ )
+ else:
+ loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
+
+ loaders.append(loader)
+
+ return loaders
+
+ @main_process
+ def _save_checkpoint(self, cur_epoch, is_best=False):
+ """
+ Save the checkpoint at the current epoch.
+ """
+ model_no_ddp = self.unwrap_dist_model(self.model)
+ param_grad_dic = {
+ k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
+ }
+ state_dict = model_no_ddp.state_dict()
+ for k in list(state_dict.keys()):
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
+ # delete parameters that do not require gradient
+ del state_dict[k]
+ save_obj = {
+ "model": state_dict,
+ "optimizer": self.optimizer.state_dict(),
+ "config": self.config.to_dict(),
+ "scaler": self.scaler.state_dict() if self.scaler else None,
+ "epoch": cur_epoch,
+ }
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
+ )
+ logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
+ torch.save(save_obj, save_to)
+
+ def _reload_best_model(self, model):
+ """
+ Load the best checkpoint for evaluation.
+ """
+ checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
+
+ logging.info("Loading checkpoint from {}.".format(checkpoint_path))
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ try:
+ model.load_state_dict(checkpoint["model"])
+ except RuntimeError as e:
+ logging.warning(
+ """
+ Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
+ Trying to load the model with strict=False.
+ """
+ )
+ model.load_state_dict(checkpoint["model"], strict=False)
+ return model
+
+ def _load_checkpoint(self, url_or_filename):
+ """
+ Resume from a checkpoint.
+ """
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location=self.device, strict=False)
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False)
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+ self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ if self.scaler and "scaler" in checkpoint:
+ self.scaler.load_state_dict(checkpoint["scaler"])
+
+ self.start_epoch = checkpoint["epoch"] + 1
+ logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+ @main_process
+ def log_stats(self, stats, split_name):
+ if isinstance(stats, dict):
+ log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ elif isinstance(stats, list):
+ pass
+
+ @main_process
+ def log_config(self):
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+ f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
diff --git a/video_llama/runners/test.py b/video_llama/runners/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/video_llama/tasks/__init__.py b/video_llama/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..86c47d8aafe2637e13f3d837904a0f51dc96b379
--- /dev/null
+++ b/video_llama/tasks/__init__.py
@@ -0,0 +1,28 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from video_llama.common.registry import registry
+from video_llama.tasks.base_task import BaseTask
+from video_llama.tasks.image_text_pretrain import ImageTextPretrainTask
+from video_llama.tasks.video_text_pretrain import VideoTextPretrainTask
+
+
+def setup_task(cfg):
+ assert "task" in cfg.run_cfg, "Task name must be provided."
+
+ task_name = cfg.run_cfg.task
+ task = registry.get_task_class(task_name).setup_task(cfg=cfg)
+ assert task is not None, "Task {} not properly registered.".format(task_name)
+
+ return task
+
+
+__all__ = [
+ "BaseTask",
+ "ImageTextPretrainTask",
+ "VideoTextPretrainTask"
+]
diff --git a/video_llama/tasks/base_task.py b/video_llama/tasks/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..d91aa346df45ed58b7d2425a59dd21f3ae8ef93c
--- /dev/null
+++ b/video_llama/tasks/base_task.py
@@ -0,0 +1,286 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from video_llama.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
+from video_llama.common.logger import MetricLogger, SmoothedValue
+from video_llama.common.registry import registry
+from video_llama.datasets.data_utils import prepare_sample
+
+
+class BaseTask:
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ self.inst_id_key = "instance_id"
+
+ @classmethod
+ def setup_task(cls, **kwargs):
+ return cls()
+
+ def build_model(self, cfg):
+ model_config = cfg.model_cfg
+
+ model_cls = registry.get_model_class(model_config.arch)
+ return model_cls.from_config(model_config)
+
+ def build_datasets(self, cfg):
+ """
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
+ Download dataset and annotations automatically if not exist.
+
+ Args:
+ cfg (common.config.Config): _description_
+
+ Returns:
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
+ """
+
+ datasets = dict()
+
+ datasets_config = cfg.datasets_cfg
+
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
+
+ for name in datasets_config:
+ dataset_config = datasets_config[name]
+
+ builder = registry.get_builder_class(name)(dataset_config)
+ dataset = builder.build_datasets()
+
+ dataset['train'].name = name
+ if 'sample_ratio' in dataset_config:
+ dataset['train'].sample_ratio = dataset_config.sample_ratio
+
+ datasets[name] = dataset
+
+ return datasets
+
+ def train_step(self, model, samples):
+ loss = model(samples)["loss"]
+ return loss
+
+ def valid_step(self, model, samples):
+ raise NotImplementedError
+
+ def before_evaluation(self, model, dataset, **kwargs):
+ model.before_evaluation(dataset=dataset, task_type=type(self))
+
+ def after_evaluation(self, **kwargs):
+ pass
+
+ def inference_step(self):
+ raise NotImplementedError
+
+ def evaluation(self, model, data_loader, cuda_enabled=True):
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation"
+ # TODO make it configurable
+ print_freq = 10
+
+ results = []
+
+ for samples in metric_logger.log_every(data_loader, print_freq, header):
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+
+ eval_output = self.valid_step(model=model, samples=samples)
+ results.extend(eval_output)
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ return results
+
+ def train_epoch(
+ self,
+ epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ cuda_enabled=False,
+ log_freq=50,
+ accum_grad_iters=1,
+ ):
+ return self._train_inner_loop(
+ epoch=epoch,
+ iters_per_epoch=lr_scheduler.iters_per_epoch,
+ model=model,
+ data_loader=data_loader,
+ optimizer=optimizer,
+ scaler=scaler,
+ lr_scheduler=lr_scheduler,
+ log_freq=log_freq,
+ cuda_enabled=cuda_enabled,
+ accum_grad_iters=accum_grad_iters,
+ )
+
+ def train_iters(
+ self,
+ epoch,
+ start_iters,
+ iters_per_inner_epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ cuda_enabled=False,
+ log_freq=50,
+ accum_grad_iters=1,
+ ):
+ return self._train_inner_loop(
+ epoch=epoch,
+ start_iters=start_iters,
+ iters_per_epoch=iters_per_inner_epoch,
+ model=model,
+ data_loader=data_loader,
+ optimizer=optimizer,
+ scaler=scaler,
+ lr_scheduler=lr_scheduler,
+ log_freq=log_freq,
+ cuda_enabled=cuda_enabled,
+ accum_grad_iters=accum_grad_iters,
+ )
+
+ def _train_inner_loop(
+ self,
+ epoch,
+ iters_per_epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ start_iters=None,
+ log_freq=50,
+ cuda_enabled=False,
+ accum_grad_iters=1,
+ ):
+ """
+ An inner training loop compatible with both epoch-based and iter-based training.
+
+ When using epoch-based, training stops after one epoch; when using iter-based,
+ training stops after #iters_per_epoch iterations.
+ """
+ use_amp = scaler is not None
+
+ if not hasattr(data_loader, "__next__"):
+ # convert to iterator if not already
+ data_loader = iter(data_loader)
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
+
+ # if iter-based runner, schedule lr based on inner epoch.
+ logging.info(
+ "Start training epoch {}, {} iters per inner epoch.".format(
+ epoch, iters_per_epoch
+ )
+ )
+ header = "Train: data epoch: [{}]".format(epoch)
+ if start_iters is None:
+ # epoch-based runner
+ inner_epoch = epoch
+ else:
+ # In iter-based runner, we schedule the learning rate based on iterations.
+ inner_epoch = start_iters // iters_per_epoch
+ header = header + "; inner epoch [{}]".format(inner_epoch)
+
+ for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
+ # if using iter-based runner, we stop after iters_per_epoch iterations.
+ if i >= iters_per_epoch:
+ break
+
+ samples = next(data_loader)
+
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+ samples.update(
+ {
+ "epoch": inner_epoch,
+ "num_iters_per_epoch": iters_per_epoch,
+ "iters": i,
+ }
+ )
+
+ lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ loss = self.train_step(model=model, samples=samples)
+
+ # after_train_step()
+ if use_amp:
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ # update gradients every accum_grad_iters iterations
+ if (i + 1) % accum_grad_iters == 0:
+ if use_amp:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ optimizer.step()
+ optimizer.zero_grad()
+
+ metric_logger.update(loss=loss.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # after train_epoch()
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
+ return {
+ k: "{:.3f}".format(meter.global_avg)
+ for k, meter in metric_logger.meters.items()
+ }
+
+ @staticmethod
+ def save_result(result, result_dir, filename, remove_duplicate=""):
+ import json
+
+ result_file = os.path.join(
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
+ )
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
+
+ json.dump(result, open(result_file, "w"))
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ if is_main_process():
+ logging.warning("rank %d starts merging results." % get_rank())
+ # combine results from all processes
+ result = []
+
+ for rank in range(get_world_size()):
+ result_file = os.path.join(
+ result_dir, "%s_rank%d.json" % (filename, rank)
+ )
+ res = json.load(open(result_file, "r"))
+ result += res
+
+ if remove_duplicate:
+ result_new = []
+ id_list = []
+ for res in result:
+ if res[remove_duplicate] not in id_list:
+ id_list.append(res[remove_duplicate])
+ result_new.append(res)
+ result = result_new
+
+ json.dump(result, open(final_result_file, "w"))
+ print("result file saved to %s" % final_result_file)
+
+ return final_result_file
diff --git a/video_llama/tasks/image_text_pretrain.py b/video_llama/tasks/image_text_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..db955f27bb7dc8093cffd95b3a26917bb681c846
--- /dev/null
+++ b/video_llama/tasks/image_text_pretrain.py
@@ -0,0 +1,18 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from video_llama.common.registry import registry
+from video_llama.tasks.base_task import BaseTask
+
+
+@registry.register_task("image_text_pretrain")
+class ImageTextPretrainTask(BaseTask):
+ def __init__(self):
+ super().__init__()
+
+ def evaluation(self, model, data_loader, cuda_enabled=True):
+ pass
diff --git a/video_llama/tasks/video_text_pretrain.py b/video_llama/tasks/video_text_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf72e5e878500e0cc3aa719c7cb20b56f63c71a
--- /dev/null
+++ b/video_llama/tasks/video_text_pretrain.py
@@ -0,0 +1,18 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from video_llama.common.registry import registry
+from video_llama.tasks.base_task import BaseTask
+
+
+@registry.register_task("video_text_pretrain")
+class VideoTextPretrainTask(BaseTask):
+ def __init__(self):
+ super().__init__()
+
+ def evaluation(self, model, data_loader, cuda_enabled=True):
+ pass