diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7e48edc1fcbd340725978ff41afa3ebd192d1db5 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: Video LLaMA +emoji: 🚀 +colorFrom: purple +colorTo: gray +sdk: gradio +sdk_version: 3.29.0 +app_file: app.py +pinned: false +license: other +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f06d76f82f514275bcf52ac69be076709b4845c5 --- /dev/null +++ b/app.py @@ -0,0 +1,192 @@ +""" +Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py +""" +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import gradio as gr + +from video_llama.common.config import Config +from video_llama.common.dist_utils import get_rank +from video_llama.common.registry import registry +from video_llama.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle +import decord +decord.bridge.set_bridge('torch') + +#%% +# imports modules for registration +from video_llama.datasets.builders import * +from video_llama.models import * +from video_llama.processors import * +from video_llama.runners import * +from video_llama.tasks import * + +#%% +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", default='eval_configs/video_llama_eval.yaml', help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +# ======================================== +# Model Initialization +# ======================================== + +print('Initializing Chat') +args = parse_args() +cfg = Config(args) + +model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) + +vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) +chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) +print('Initialization Finished') + +# ======================================== +# Gradio Setting +# ======================================== + +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list + +def upload_imgorvideo(gr_video, gr_img, text_input, chat_state): + if gr_img is None and gr_video is None: + return None, None, None, gr.update(interactive=True), chat_state, None + elif gr_img is not None and gr_video is None: + print(gr_img) + chat_state = Conversation( + system= "You are able to understand the visual content that the user provides." + "Follow the instructions carefully and explain your answers in detail.", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + ) + img_list = [] + llm_message = chat.upload_img(gr_img, chat_state, img_list) + return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list + elif gr_video is not None and gr_img is None: + print(gr_video) + chat_state = default_conversation.copy() + chat_state = Conversation( + system= "You are able to understand the visual content that the user provides." + "Follow the instructions carefully and explain your answers in detail.", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + ) + img_list = [] + llm_message = chat.upload_video(gr_video, chat_state, img_list) + return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list + else: + # img_list = [] + return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None + +def gradio_ask(user_message, chatbot, chat_state): + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state + chat.ask(user_message, chat_state) + chatbot = chatbot + [[user_message, None]] + return '', chatbot, chat_state + + +def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): + llm_message = chat.answer(conv=chat_state, + img_list=img_list, + num_beams=num_beams, + temperature=temperature, + max_new_tokens=300, + max_length=2000)[0] + chatbot[-1][1] = llm_message + print(chat_state.get_prompt()) + print(chat_state) + return chatbot, chat_state, img_list + +title = """

Demo of Video-LLaMA

""" +description = """

This is the demo of Video-LLaMA. Upload your images/videos and start chatting!

""" + + +#TODO show examples below + +with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + with gr.Column(scale=0.5): + video = gr.Video() + image = gr.Image(type="pil") + + upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") + clear = gr.Button("Restart") + + num_beams = gr.Slider( + minimum=1, + maximum=10, + value=1, + step=1, + interactive=True, + label="beam search numbers)", + ) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + with gr.Column(): + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot(label='Video-LLaMA') + text_input = gr.Textbox(label='User', placeholder='Please upload your image/video first', interactive=False) + + + upload_button.click(upload_imgorvideo, [video, image, text_input, chat_state], [video, image, text_input, upload_button, chat_state, img_list]) + + text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then( + gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] + ) + clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, image, text_input, upload_button, chat_state, img_list], queue=False) + +demo.launch(share=False, enable_queue=False) + +# %% diff --git a/ckpt/blip2_pretrained_flant5xxl.pth b/ckpt/blip2_pretrained_flant5xxl.pth new file mode 100644 index 0000000000000000000000000000000000000000..70fa1bcf8e62b78c270d100025736323df52e52c --- /dev/null +++ b/ckpt/blip2_pretrained_flant5xxl.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b3839ea6c617f315ead9bf4036bbb0f0cf6bf62695ecfc14968ea626af03a29 +size 433481467 diff --git a/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth b/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth new file mode 100644 index 0000000000000000000000000000000000000000..9e48012e9ecebc4811525b863b539295d358971a --- /dev/null +++ b/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc4b32437c90df51bc3faa29deaa9b25ab77e1707ac79066f17ae3193ebe8bfc +size 1527692539 diff --git a/ckpt/finetune-vicuna7b-v2.pth b/ckpt/finetune-vicuna7b-v2.pth new file mode 100644 index 0000000000000000000000000000000000000000..ffcf94d3949e0d60d0c6d79392fbded8ea1aabe0 --- /dev/null +++ b/ckpt/finetune-vicuna7b-v2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0680ad8eb14c2a3273b7be71309ab6b06c9f426e87ad4675a903371fe0fa8162 +size 265436777 diff --git a/ckpt/pretrain-billa7b-zh.pth b/ckpt/pretrain-billa7b-zh.pth new file mode 100644 index 0000000000000000000000000000000000000000..e70ac5f43911ddecf2052cc1be40e5cce33f1c2f --- /dev/null +++ b/ckpt/pretrain-billa7b-zh.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f50a51db3055e1be6461f6dec833fbbbba28650287d26c8787664c8ee31dcf0f +size 265435689 diff --git a/eval_configs/video_llama_eval.yaml b/eval_configs/video_llama_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32be978a877ecdf08594f843871c769341b16903 --- /dev/null +++ b/eval_configs/video_llama_eval.yaml @@ -0,0 +1,32 @@ +model: + arch: video_llama + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 512 + end_sym: "###" + low_resource: False + + + llama_model: "DAMO-NLP-SG/vicuna-7b" + + fusion_head_layers: 2 + max_frame_pos: 32 + fusion_header_type: "seqTransf" + + ckpt: 'ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth' + q_former_model: 'ckpt/blip2_pretrained_flant5xxl.pth' + +datasets: + webvid: + vis_processor: + train: + name: "alpro_video_eval" + n_frms: 8 + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: video_text_pretrain \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2d2ac49c79d36d9c4ca9262763499cf4bba2c81 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +transformers==4.28.0 +tqdm +decord +timm +einops +opencv_python +torchvision + +salesforce-lavis +bitsandbytes +accelerate \ No newline at end of file diff --git a/video_llama/__init__.py b/video_llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90b38d4c6f7611646be527265cd0bd868d4f1f81 --- /dev/null +++ b/video_llama/__init__.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import sys + +from omegaconf import OmegaConf + +from video_llama.common.registry import registry + +from video_llama.datasets.builders import * +from video_llama.models import * +from video_llama.processors import * +from video_llama.tasks import * + + +root_dir = os.path.dirname(os.path.abspath(__file__)) +default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) + +registry.register_path("library_root", root_dir) +repo_root = os.path.join(root_dir, "..") +registry.register_path("repo_root", repo_root) +cache_root = os.path.join(repo_root, default_cfg.env.cache_root) +registry.register_path("cache_root", cache_root) + +registry.register("MAX_INT", sys.maxsize) +registry.register("SPLIT_NAMES", ["train", "val", "test"]) diff --git a/video_llama/app.py b/video_llama/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f06d76f82f514275bcf52ac69be076709b4845c5 --- /dev/null +++ b/video_llama/app.py @@ -0,0 +1,192 @@ +""" +Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py +""" +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import gradio as gr + +from video_llama.common.config import Config +from video_llama.common.dist_utils import get_rank +from video_llama.common.registry import registry +from video_llama.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle +import decord +decord.bridge.set_bridge('torch') + +#%% +# imports modules for registration +from video_llama.datasets.builders import * +from video_llama.models import * +from video_llama.processors import * +from video_llama.runners import * +from video_llama.tasks import * + +#%% +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", default='eval_configs/video_llama_eval.yaml', help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +# ======================================== +# Model Initialization +# ======================================== + +print('Initializing Chat') +args = parse_args() +cfg = Config(args) + +model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) + +vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) +chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) +print('Initialization Finished') + +# ======================================== +# Gradio Setting +# ======================================== + +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list + +def upload_imgorvideo(gr_video, gr_img, text_input, chat_state): + if gr_img is None and gr_video is None: + return None, None, None, gr.update(interactive=True), chat_state, None + elif gr_img is not None and gr_video is None: + print(gr_img) + chat_state = Conversation( + system= "You are able to understand the visual content that the user provides." + "Follow the instructions carefully and explain your answers in detail.", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + ) + img_list = [] + llm_message = chat.upload_img(gr_img, chat_state, img_list) + return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list + elif gr_video is not None and gr_img is None: + print(gr_video) + chat_state = default_conversation.copy() + chat_state = Conversation( + system= "You are able to understand the visual content that the user provides." + "Follow the instructions carefully and explain your answers in detail.", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + ) + img_list = [] + llm_message = chat.upload_video(gr_video, chat_state, img_list) + return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list + else: + # img_list = [] + return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None + +def gradio_ask(user_message, chatbot, chat_state): + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state + chat.ask(user_message, chat_state) + chatbot = chatbot + [[user_message, None]] + return '', chatbot, chat_state + + +def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): + llm_message = chat.answer(conv=chat_state, + img_list=img_list, + num_beams=num_beams, + temperature=temperature, + max_new_tokens=300, + max_length=2000)[0] + chatbot[-1][1] = llm_message + print(chat_state.get_prompt()) + print(chat_state) + return chatbot, chat_state, img_list + +title = """

Demo of Video-LLaMA

""" +description = """

This is the demo of Video-LLaMA. Upload your images/videos and start chatting!

""" + + +#TODO show examples below + +with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + with gr.Column(scale=0.5): + video = gr.Video() + image = gr.Image(type="pil") + + upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") + clear = gr.Button("Restart") + + num_beams = gr.Slider( + minimum=1, + maximum=10, + value=1, + step=1, + interactive=True, + label="beam search numbers)", + ) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + with gr.Column(): + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot(label='Video-LLaMA') + text_input = gr.Textbox(label='User', placeholder='Please upload your image/video first', interactive=False) + + + upload_button.click(upload_imgorvideo, [video, image, text_input, chat_state], [video, image, text_input, upload_button, chat_state, img_list]) + + text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then( + gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] + ) + clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, image, text_input, upload_button, chat_state, img_list], queue=False) + +demo.launch(share=False, enable_queue=False) + +# %% diff --git a/video_llama/ckpt/blip2_pretrained_flant5xxl.pth b/video_llama/ckpt/blip2_pretrained_flant5xxl.pth new file mode 100644 index 0000000000000000000000000000000000000000..70fa1bcf8e62b78c270d100025736323df52e52c --- /dev/null +++ b/video_llama/ckpt/blip2_pretrained_flant5xxl.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b3839ea6c617f315ead9bf4036bbb0f0cf6bf62695ecfc14968ea626af03a29 +size 433481467 diff --git a/video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth b/video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth new file mode 100644 index 0000000000000000000000000000000000000000..9c45fc8ea3c715c39e1d18de5ac6c63b22e9b224 --- /dev/null +++ b/video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46af76d307c14d28c56534e4bf8654343e5512aa1285fc1c1fdb5728c418e7ca +size 623104000 diff --git a/video_llama/ckpt/pretrain-billa7b-zh.pth b/video_llama/ckpt/pretrain-billa7b-zh.pth new file mode 100644 index 0000000000000000000000000000000000000000..e70ac5f43911ddecf2052cc1be40e5cce33f1c2f --- /dev/null +++ b/video_llama/ckpt/pretrain-billa7b-zh.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f50a51db3055e1be6461f6dec833fbbbba28650287d26c8787664c8ee31dcf0f +size 265435689 diff --git a/video_llama/common/__init__.py b/video_llama/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video_llama/common/config.py b/video_llama/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a718644414610b2f10a01a51ebe58c14ff5f30 --- /dev/null +++ b/video_llama/common/config.py @@ -0,0 +1,468 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +from typing import Dict + +from omegaconf import OmegaConf +from video_llama.common.registry import registry + + +class Config: + def __init__(self, args): + self.config = {} + + self.args = args + + # Register the config and configuration for setup + registry.register("configuration", self) + + user_config = self._build_opt_list(self.args.options) + + config = OmegaConf.load(self.args.cfg_path) + + runner_config = self.build_runner_config(config) + model_config = self.build_model_config(config, **user_config) + dataset_config = self.build_dataset_config(config) + + # Validate the user-provided runner configuration + # model and dataset configuration are supposed to be validated by the respective classes + # [TODO] validate the model/dataset configuration + # self._validate_runner_config(runner_config) + + # Override the default configuration with user options. + self.config = OmegaConf.merge( + runner_config, model_config, dataset_config, user_config + ) + + def _validate_runner_config(self, runner_config): + """ + This method validates the configuration, such that + 1) all the user specified options are valid; + 2) no type mismatches between the user specified options and the config. + """ + runner_config_validator = create_runner_config_validator() + runner_config_validator.validate(runner_config) + + def _build_opt_list(self, opts): + opts_dot_list = self._convert_to_dot_list(opts) + return OmegaConf.from_dotlist(opts_dot_list) + + @staticmethod + def build_model_config(config, **kwargs): + model = config.get("model", None) + assert model is not None, "Missing model configuration file." + + model_cls = registry.get_model_class(model.arch) + assert model_cls is not None, f"Model '{model.arch}' has not been registered." + + model_type = kwargs.get("model.model_type", None) + if not model_type: + model_type = model.get("model_type", None) + # else use the model type selected by user. + + assert model_type is not None, "Missing model_type." + + model_config_path = model_cls.default_config_path(model_type=model_type) + + model_config = OmegaConf.create() + # hierarchy override, customized config > default config + model_config = OmegaConf.merge( + model_config, + OmegaConf.load(model_config_path), + {"model": config["model"]}, + ) + + return model_config + + @staticmethod + def build_runner_config(config): + return {"run": config.run} + + @staticmethod + def build_dataset_config(config): + datasets = config.get("datasets", None) + if datasets is None: + raise KeyError( + "Expecting 'datasets' as the root key for dataset configuration." + ) + + dataset_config = OmegaConf.create() + + for dataset_name in datasets: + builder_cls = registry.get_builder_class(dataset_name) + + dataset_config_type = datasets[dataset_name].get("type", "default") + dataset_config_path = builder_cls.default_config_path( + type=dataset_config_type + ) + + # hierarchy override, customized config > default config + dataset_config = OmegaConf.merge( + dataset_config, + OmegaConf.load(dataset_config_path), + {"datasets": {dataset_name: config["datasets"][dataset_name]}}, + ) + + return dataset_config + + def _convert_to_dot_list(self, opts): + if opts is None: + opts = [] + + if len(opts) == 0: + return opts + + has_equal = opts[0].find("=") != -1 + + if has_equal: + return opts + + return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] + + def get_config(self): + return self.config + + @property + def run_cfg(self): + return self.config.run + + @property + def datasets_cfg(self): + return self.config.datasets + + @property + def model_cfg(self): + return self.config.model + + def pretty_print(self): + logging.info("\n===== Running Parameters =====") + logging.info(self._convert_node_to_json(self.config.run)) + + logging.info("\n====== Dataset Attributes ======") + datasets = self.config.datasets + + for dataset in datasets: + if dataset in self.config.datasets: + logging.info(f"\n======== {dataset} =======") + dataset_config = self.config.datasets[dataset] + logging.info(self._convert_node_to_json(dataset_config)) + else: + logging.warning(f"No dataset named '{dataset}' in config. Skipping") + + logging.info(f"\n====== Model Attributes ======") + logging.info(self._convert_node_to_json(self.config.model)) + + def _convert_node_to_json(self, node): + container = OmegaConf.to_container(node, resolve=True) + return json.dumps(container, indent=4, sort_keys=True) + + def to_dict(self): + return OmegaConf.to_container(self.config) + + +def node_to_dict(node): + return OmegaConf.to_container(node) + + +class ConfigValidator: + """ + This is a preliminary implementation to centralize and validate the configuration. + May be altered in the future. + + A helper class to validate configurations from yaml file. + + This serves the following purposes: + 1. Ensure all the options in the yaml are defined, raise error if not. + 2. when type mismatches are found, the validator will raise an error. + 3. a central place to store and display helpful messages for supported configurations. + + """ + + class _Argument: + def __init__(self, name, choices=None, type=None, help=None): + self.name = name + self.val = None + self.choices = choices + self.type = type + self.help = help + + def __str__(self): + s = f"{self.name}={self.val}" + if self.type is not None: + s += f", ({self.type})" + if self.choices is not None: + s += f", choices: {self.choices}" + if self.help is not None: + s += f", ({self.help})" + return s + + def __init__(self, description): + self.description = description + + self.arguments = dict() + + self.parsed_args = None + + def __getitem__(self, key): + assert self.parsed_args is not None, "No arguments parsed yet." + + return self.parsed_args[key] + + def __str__(self) -> str: + return self.format_help() + + def add_argument(self, *args, **kwargs): + """ + Assume the first argument is the name of the argument. + """ + self.arguments[args[0]] = self._Argument(*args, **kwargs) + + def validate(self, config=None): + """ + Convert yaml config (dict-like) to list, required by argparse. + """ + for k, v in config.items(): + assert ( + k in self.arguments + ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}.""" + + if self.arguments[k].type is not None: + try: + self.arguments[k].val = self.arguments[k].type(v) + except ValueError: + raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") + + if self.arguments[k].choices is not None: + assert ( + v in self.arguments[k].choices + ), f"""{k} must be one of {self.arguments[k].choices}.""" + + return config + + def format_arguments(self): + return str([f"{k}" for k in sorted(self.arguments.keys())]) + + def format_help(self): + # description + key-value pair string for each argument + help_msg = str(self.description) + return help_msg + ", available arguments: " + self.format_arguments() + + def print_help(self): + # display help message + print(self.format_help()) + + +def create_runner_config_validator(): + validator = ConfigValidator(description="Runner configurations") + + validator.add_argument( + "runner", + type=str, + choices=["runner_base", "runner_iter"], + help="""Runner to use. The "runner_base" uses epoch-based training while iter-based + runner runs based on iters. Default: runner_base""", + ) + # add argumetns for training dataset ratios + validator.add_argument( + "train_dataset_ratios", + type=Dict[str, float], + help="""Ratios of training dataset. This is used in iteration-based runner. + Do not support for epoch-based runner because how to define an epoch becomes tricky. + Default: None""", + ) + validator.add_argument( + "max_iters", + type=float, + help="Maximum number of iterations to run.", + ) + validator.add_argument( + "max_epoch", + type=int, + help="Maximum number of epochs to run.", + ) + # add arguments for iters_per_inner_epoch + validator.add_argument( + "iters_per_inner_epoch", + type=float, + help="Number of iterations per inner epoch. This is required when runner is runner_iter.", + ) + lr_scheds_choices = registry.list_lr_schedulers() + validator.add_argument( + "lr_sched", + type=str, + choices=lr_scheds_choices, + help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), + ) + task_choices = registry.list_tasks() + validator.add_argument( + "task", + type=str, + choices=task_choices, + help="Task to use, from {}".format(task_choices), + ) + # add arguments for init_lr + validator.add_argument( + "init_lr", + type=float, + help="Initial learning rate. This will be the learning rate after warmup and before decay.", + ) + # add arguments for min_lr + validator.add_argument( + "min_lr", + type=float, + help="Minimum learning rate (after decay).", + ) + # add arguments for warmup_lr + validator.add_argument( + "warmup_lr", + type=float, + help="Starting learning rate for warmup.", + ) + # add arguments for learning rate decay rate + validator.add_argument( + "lr_decay_rate", + type=float, + help="Learning rate decay rate. Required if using a decaying learning rate scheduler.", + ) + # add arguments for weight decay + validator.add_argument( + "weight_decay", + type=float, + help="Weight decay rate.", + ) + # add arguments for training batch size + validator.add_argument( + "batch_size_train", + type=int, + help="Training batch size.", + ) + # add arguments for evaluation batch size + validator.add_argument( + "batch_size_eval", + type=int, + help="Evaluation batch size, including validation and testing.", + ) + # add arguments for number of workers for data loading + validator.add_argument( + "num_workers", + help="Number of workers for data loading.", + ) + # add arguments for warm up steps + validator.add_argument( + "warmup_steps", + type=int, + help="Number of warmup steps. Required if a warmup schedule is used.", + ) + # add arguments for random seed + validator.add_argument( + "seed", + type=int, + help="Random seed.", + ) + # add arguments for output directory + validator.add_argument( + "output_dir", + type=str, + help="Output directory to save checkpoints and logs.", + ) + # add arguments for whether only use evaluation + validator.add_argument( + "evaluate", + help="Whether to only evaluate the model. If true, training will not be performed.", + ) + # add arguments for splits used for training, e.g. ["train", "val"] + validator.add_argument( + "train_splits", + type=list, + help="Splits to use for training.", + ) + # add arguments for splits used for validation, e.g. ["val"] + validator.add_argument( + "valid_splits", + type=list, + help="Splits to use for validation. If not provided, will skip the validation.", + ) + # add arguments for splits used for testing, e.g. ["test"] + validator.add_argument( + "test_splits", + type=list, + help="Splits to use for testing. If not provided, will skip the testing.", + ) + # add arguments for accumulating gradient for iterations + validator.add_argument( + "accum_grad_iters", + type=int, + help="Number of iterations to accumulate gradient for.", + ) + + # ====== distributed training ====== + validator.add_argument( + "device", + type=str, + choices=["cpu", "cuda"], + help="Device to use. Support 'cuda' or 'cpu' as for now.", + ) + validator.add_argument( + "world_size", + type=int, + help="Number of processes participating in the job.", + ) + validator.add_argument("dist_url", type=str) + validator.add_argument("distributed", type=bool) + # add arguments to opt using distributed sampler during evaluation or not + validator.add_argument( + "use_dist_eval_sampler", + type=bool, + help="Whether to use distributed sampler during evaluation or not.", + ) + + # ====== task specific ====== + # generation task specific arguments + # add arguments for maximal length of text output + validator.add_argument( + "max_len", + type=int, + help="Maximal length of text output.", + ) + # add arguments for minimal length of text output + validator.add_argument( + "min_len", + type=int, + help="Minimal length of text output.", + ) + # add arguments number of beams + validator.add_argument( + "num_beams", + type=int, + help="Number of beams used for beam search.", + ) + + # vqa task specific arguments + # add arguments for number of answer candidates + validator.add_argument( + "num_ans_candidates", + type=int, + help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", + ) + # add arguments for inference method + validator.add_argument( + "inference_method", + type=str, + choices=["genearte", "rank"], + help="""Inference method to use for question answering. If rank, requires a answer list.""", + ) + + # ====== model specific ====== + validator.add_argument( + "k_test", + type=int, + help="Number of top k most similar samples from ITC/VTC selection to be tested.", + ) + + return validator diff --git a/video_llama/common/dist_utils.py b/video_llama/common/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9280150bf5122d51bb810a9f0258a233e7088647 --- /dev/null +++ b/video_llama/common/dist_utils.py @@ -0,0 +1,137 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import functools +import os + +import torch +import torch.distributed as dist +import timm.models.hub as timm_hub + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}, world {}): {}".format( + args.rank, args.world_size, args.dist_url + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + timeout=datetime.timedelta( + days=365 + ), # allow auto-downloading and de-compressing + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def get_dist_info(): + if torch.__version__ < "1.0": + initialized = dist._initialized + else: + initialized = dist.is_initialized() + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: # non-distributed training + rank = 0 + world_size = 1 + return rank, world_size + + +def main_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def download_cached_file(url, check_hash=True, progress=False): + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. + """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() diff --git a/video_llama/common/gradcam.py b/video_llama/common/gradcam.py new file mode 100644 index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0 --- /dev/null +++ b/video_llama/common/gradcam.py @@ -0,0 +1,24 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.ndimage import filters +from skimage import transform as skimage_transform + + +def getAttMap(img, attMap, blur=True, overlap=True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") + if blur: + attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap("jet") + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = ( + 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV + ) + return attMap diff --git a/video_llama/common/logger.py b/video_llama/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e9de0e65de4601e9ea10f3af6372adabdd09e44c --- /dev/null +++ b/video_llama/common/logger.py @@ -0,0 +1,195 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + +from video_llama.common import dist_utils + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not dist_utils.is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def setup_logger(): + logging.basicConfig( + level=logging.INFO if dist_utils.is_main_process() else logging.WARN, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], + ) diff --git a/video_llama/common/optims.py b/video_llama/common/optims.py new file mode 100644 index 0000000000000000000000000000000000000000..b466e38dc4ceba80ea54759ba608b7281c583bed --- /dev/null +++ b/video_llama/common/optims.py @@ -0,0 +1,119 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math + +from video_llama.common.registry import registry + + +@registry.register_lr_scheduler("linear_warmup_step_lr") +class LinearWarmupStepLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + min_lr, + init_lr, + decay_rate=1, + warmup_start_lr=-1, + warmup_steps=0, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.min_lr = min_lr + + self.decay_rate = decay_rate + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + if cur_epoch == 0: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + step_lr_schedule( + epoch=cur_epoch, + optimizer=self.optimizer, + init_lr=self.init_lr, + min_lr=self.min_lr, + decay_rate=self.decay_rate, + ) + + +@registry.register_lr_scheduler("linear_warmup_cosine_lr") +class LinearWarmupCosineLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + iters_per_epoch, + min_lr, + init_lr, + warmup_steps=0, + warmup_start_lr=-1, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.iters_per_epoch = iters_per_epoch + self.min_lr = min_lr + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + total_cur_step = cur_epoch * self.iters_per_epoch + cur_step + if total_cur_step < self.warmup_steps: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + cosine_lr_schedule( + epoch=total_cur_step, + optimizer=self.optimizer, + max_epoch=self.max_epoch * self.iters_per_epoch, + init_lr=self.init_lr, + min_lr=self.min_lr, + ) + + +def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): + """Decay the learning rate""" + lr = (init_lr - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * epoch / max_epoch) + ) + min_lr + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): + """Warmup the learning rate""" + lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): + """Decay the learning rate""" + lr = max(min_lr, init_lr * (decay_rate**epoch)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr diff --git a/video_llama/common/registry.py b/video_llama/common/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d0171ceb39123c7402a554eb8543ce55ff6881 --- /dev/null +++ b/video_llama/common/registry.py @@ -0,0 +1,329 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + + +class Registry: + mapping = { + "builder_name_mapping": {}, + "task_name_mapping": {}, + "processor_name_mapping": {}, + "model_name_mapping": {}, + "lr_scheduler_name_mapping": {}, + "runner_name_mapping": {}, + "state": {}, + "paths": {}, + } + + @classmethod + def register_builder(cls, name): + r"""Register a dataset builder to registry with key 'name' + + Args: + name: Key with which the builder will be registered. + + Usage: + + from video_llama.common.registry import registry + from video_llama.datasets.base_dataset_builder import BaseDatasetBuilder + """ + + def wrap(builder_cls): + from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder + + assert issubclass( + builder_cls, BaseDatasetBuilder + ), "All builders must inherit BaseDatasetBuilder class, found {}".format( + builder_cls + ) + if name in cls.mapping["builder_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["builder_name_mapping"][name] + ) + ) + cls.mapping["builder_name_mapping"][name] = builder_cls + return builder_cls + + return wrap + + @classmethod + def register_task(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from video_llama.common.registry import registry + """ + + def wrap(task_cls): + from video_llama.tasks.base_task import BaseTask + + assert issubclass( + task_cls, BaseTask + ), "All tasks must inherit BaseTask class" + if name in cls.mapping["task_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["task_name_mapping"][name] + ) + ) + cls.mapping["task_name_mapping"][name] = task_cls + return task_cls + + return wrap + + @classmethod + def register_model(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from video_llama.common.registry import registry + """ + + def wrap(model_cls): + from video_llama.models import BaseModel + + assert issubclass( + model_cls, BaseModel + ), "All models must inherit BaseModel class" + if name in cls.mapping["model_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["model_name_mapping"][name] + ) + ) + cls.mapping["model_name_mapping"][name] = model_cls + return model_cls + + return wrap + + @classmethod + def register_processor(cls, name): + r"""Register a processor to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from video_llama.common.registry import registry + """ + + def wrap(processor_cls): + from video_llama.processors import BaseProcessor + + assert issubclass( + processor_cls, BaseProcessor + ), "All processors must inherit BaseProcessor class" + if name in cls.mapping["processor_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["processor_name_mapping"][name] + ) + ) + cls.mapping["processor_name_mapping"][name] = processor_cls + return processor_cls + + return wrap + + @classmethod + def register_lr_scheduler(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from video_llama.common.registry import registry + """ + + def wrap(lr_sched_cls): + if name in cls.mapping["lr_scheduler_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["lr_scheduler_name_mapping"][name] + ) + ) + cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls + return lr_sched_cls + + return wrap + + @classmethod + def register_runner(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from video_llama.common.registry import registry + """ + + def wrap(runner_cls): + if name in cls.mapping["runner_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["runner_name_mapping"][name] + ) + ) + cls.mapping["runner_name_mapping"][name] = runner_cls + return runner_cls + + return wrap + + @classmethod + def register_path(cls, name, path): + r"""Register a path to registry with key 'name' + + Args: + name: Key with which the path will be registered. + + Usage: + + from video_llama.common.registry import registry + """ + assert isinstance(path, str), "All path must be str." + if name in cls.mapping["paths"]: + raise KeyError("Name '{}' already registered.".format(name)) + cls.mapping["paths"][name] = path + + @classmethod + def register(cls, name, obj): + r"""Register an item to registry with key 'name' + + Args: + name: Key with which the item will be registered. + + Usage:: + + from video_llama.common.registry import registry + + registry.register("config", {}) + """ + path = name.split(".") + current = cls.mapping["state"] + + for part in path[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + current[path[-1]] = obj + + # @classmethod + # def get_trainer_class(cls, name): + # return cls.mapping["trainer_name_mapping"].get(name, None) + + @classmethod + def get_builder_class(cls, name): + return cls.mapping["builder_name_mapping"].get(name, None) + + @classmethod + def get_model_class(cls, name): + return cls.mapping["model_name_mapping"].get(name, None) + + @classmethod + def get_task_class(cls, name): + return cls.mapping["task_name_mapping"].get(name, None) + + @classmethod + def get_processor_class(cls, name): + return cls.mapping["processor_name_mapping"].get(name, None) + + @classmethod + def get_lr_scheduler_class(cls, name): + return cls.mapping["lr_scheduler_name_mapping"].get(name, None) + + @classmethod + def get_runner_class(cls, name): + return cls.mapping["runner_name_mapping"].get(name, None) + + @classmethod + def list_runners(cls): + return sorted(cls.mapping["runner_name_mapping"].keys()) + + @classmethod + def list_models(cls): + return sorted(cls.mapping["model_name_mapping"].keys()) + + @classmethod + def list_tasks(cls): + return sorted(cls.mapping["task_name_mapping"].keys()) + + @classmethod + def list_processors(cls): + return sorted(cls.mapping["processor_name_mapping"].keys()) + + @classmethod + def list_lr_schedulers(cls): + return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) + + @classmethod + def list_datasets(cls): + return sorted(cls.mapping["builder_name_mapping"].keys()) + + @classmethod + def get_path(cls, name): + return cls.mapping["paths"].get(name, None) + + @classmethod + def get(cls, name, default=None, no_warning=False): + r"""Get an item from registry with key 'name' + + Args: + name (string): Key whose value needs to be retrieved. + default: If passed and key is not in registry, default value will + be returned with a warning. Default: None + no_warning (bool): If passed as True, warning when key doesn't exist + will not be generated. Useful for MMF's + internal operations. Default: False + """ + original_name = name + name = name.split(".") + value = cls.mapping["state"] + for subname in name: + value = value.get(subname, default) + if value is default: + break + + if ( + "writer" in cls.mapping["state"] + and value == default + and no_warning is False + ): + cls.mapping["state"]["writer"].warning( + "Key {} is not present in registry, returning default value " + "of {}".format(original_name, default) + ) + return value + + @classmethod + def unregister(cls, name): + r"""Remove an item from registry with key 'name' + + Args: + name: Key which needs to be removed. + Usage:: + + from mmf.common.registry import registry + + config = registry.unregister("config") + """ + return cls.mapping["state"].pop(name, None) + + +registry = Registry() diff --git a/video_llama/common/utils.py b/video_llama/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1768fcdfd73b057877a7b0a7c1f10a3aa057caa --- /dev/null +++ b/video_llama/common/utils.py @@ -0,0 +1,424 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import io +import json +import logging +import os +import pickle +import re +import shutil +import urllib +import urllib.error +import urllib.request +from typing import Optional +from urllib.parse import urlparse + +import numpy as np +import pandas as pd +import yaml +from iopath.common.download import download +from iopath.common.file_io import file_lock, g_pathmgr +from video_llama.common.registry import registry +from torch.utils.model_zoo import tqdm +from torchvision.datasets.utils import ( + check_integrity, + download_file_from_google_drive, + extract_archive, +) + + +def now(): + from datetime import datetime + + return datetime.now().strftime("%Y%m%d%H%M")[:-1] + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def get_cache_path(rel_path): + return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) + + +def get_abs_path(rel_path): + return os.path.join(registry.get_path("library_root"), rel_path) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +# The following are adapted from torchvision and vissl +# torchvision: https://github.com/pytorch/vision +# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + print(f"Error creating directory: {dir_path}") + return is_success + + +def get_redirected_url(url: str): + """ + Given a URL, returns the URL it redirects to or the + original URL in case of no indirection + """ + import requests + + with requests.Session() as session: + with session.get(url, stream=True, allow_redirects=True) as response: + if response.history: + return response.url + else: + return url + + +def to_google_drive_download_url(view_url: str) -> str: + """ + Utility function to transform a view URL of google drive + to a download URL for google drive + Example input: + https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view + Example output: + https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp + """ + splits = view_url.split("/") + assert splits[-1] == "view" + file_id = splits[-2] + return f"https://drive.google.com/uc?export=download&id={file_id}" + + +def download_google_drive_url(url: str, output_path: str, output_file_name: str): + """ + Download a file from google drive + Downloading an URL from google drive requires confirmation when + the file of the size is too big (google drive notifies that + anti-viral checks cannot be performed on such files) + """ + import requests + + with requests.Session() as session: + + # First get the confirmation token and append it to the URL + with session.get(url, stream=True, allow_redirects=True) as response: + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + url = url + "&confirm=" + v + + # Then download the content of the file + with session.get(url, stream=True, verify=True) as response: + makedir(output_path) + path = os.path.join(output_path, output_file_name) + total_size = int(response.headers.get("Content-length", 0)) + with open(path, "wb") as file: + from tqdm import tqdm + + with tqdm(total=total_size) as progress_bar: + for block in response.iter_content( + chunk_size=io.DEFAULT_BUFFER_SIZE + ): + file.write(block) + progress_bar.update(len(block)) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen( + urllib.request.Request(url, headers={"User-Agent": "vissl"}) + ) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def download_url( + url: str, + root: str, + filename: Optional[str] = None, + md5: Optional[str] = None, +) -> None: + """Download a file from a url and place it in root. + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. + If None, use the basename of the URL. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir(root) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + # expand redirect chain if needed + url = get_redirected_url(url) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + fpath + ) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + + +def cache_url(url: str, cache_dir: str) -> str: + """ + This implementation downloads the remote resource and caches it locally. + The resource will only be downloaded if not previously requested. + """ + parsed_url = urlparse(url) + dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) + makedir(dirname) + filename = url.split("/")[-1] + cached = os.path.join(dirname, filename) + with file_lock(cached): + if not os.path.isfile(cached): + logging.info(f"Downloading {url} to {cached} ...") + cached = download(url, dirname, filename=filename) + logging.info(f"URL {url} cached in {cached}") + return cached + + +# TODO (prigoyal): convert this into RAII-style API +def create_file_symlink(file1, file2): + """ + Simply create the symlinks for a given file1 to file2. + Useful during model checkpointing to symlinks to the + latest successful checkpoint. + """ + try: + if g_pathmgr.exists(file2): + g_pathmgr.rm(file2) + g_pathmgr.symlink(file1, file2) + except Exception as e: + logging.info(f"Could NOT create symlink. Error: {e}") + + +def save_file(data, filename, append_to_json=True, verbose=True): + """ + Common i/o utility to handle saving data to various file formats. + Supported: + .pkl, .pickle, .npy, .json + Specifically for .json, users have the option to either append (default) + or rewrite by passing in Boolean value to append_to_json. + """ + if verbose: + logging.info(f"Saving data to file: {filename}") + file_ext = os.path.splitext(filename)[1] + if file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "wb") as fopen: + pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) + elif file_ext == ".npy": + with g_pathmgr.open(filename, "wb") as fopen: + np.save(fopen, data) + elif file_ext == ".json": + if append_to_json: + with g_pathmgr.open(filename, "a") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + else: + with g_pathmgr.open(filename, "w") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "w") as fopen: + dump = yaml.dump(data) + fopen.write(dump) + fopen.flush() + else: + raise Exception(f"Saving {file_ext} is not supported yet") + + if verbose: + logging.info(f"Saved data to file: {filename}") + + +def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): + """ + Common i/o utility to handle loading data from various file formats. + Supported: + .pkl, .pickle, .npy, .json + For the npy files, we support reading the files in mmap_mode. + If the mmap_mode of reading is not successful, we load data without the + mmap_mode. + """ + if verbose: + logging.info(f"Loading data from file: {filename}") + + file_ext = os.path.splitext(filename)[1] + if file_ext == ".txt": + with g_pathmgr.open(filename, "r") as fopen: + data = fopen.readlines() + elif file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "rb") as fopen: + data = pickle.load(fopen, encoding="latin1") + elif file_ext == ".npy": + if mmap_mode: + try: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load( + fopen, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + except ValueError as e: + logging.info( + f"Could not mmap {filename}: {e}. Trying without g_pathmgr" + ) + data = np.load( + filename, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + logging.info("Successfully loaded without g_pathmgr") + except Exception: + logging.info("Could not mmap without g_pathmgr. Trying without mmap") + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + else: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + elif file_ext == ".json": + with g_pathmgr.open(filename, "r") as fopen: + data = json.load(fopen) + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "r") as fopen: + data = yaml.load(fopen, Loader=yaml.FullLoader) + elif file_ext == ".csv": + with g_pathmgr.open(filename, "r") as fopen: + data = pd.read_csv(fopen) + else: + raise Exception(f"Reading from {file_ext} is not supported yet") + return data + + +def abspath(resource_path: str): + """ + Make a path absolute, but take into account prefixes like + "http://" or "manifold://" + """ + regex = re.compile(r"^\w+://") + if regex.match(resource_path) is None: + return os.path.abspath(resource_path) + else: + return resource_path + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_url(input_url): + """ + Check if an input string is a url. look for http(s):// and ignoring the case + """ + is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None + return is_url + + +def cleanup_dir(dir): + """ + Utility for deleting a directory. Useful for cleaning the storage space + that contains various training artifacts like checkpoints, data etc. + """ + if os.path.exists(dir): + logging.info(f"Deleting directory: {dir}") + shutil.rmtree(dir) + logging.info(f"Deleted contents of directory: {dir}") + + +def get_file_size(filename): + """ + Given a file, get the size of file in MB + """ + size_in_mb = os.path.getsize(filename) / float(1024**2) + return size_in_mb diff --git a/video_llama/configs/datasets/cc_sbu/align.yaml b/video_llama/configs/datasets/cc_sbu/align.yaml new file mode 100644 index 0000000000000000000000000000000000000000..180ea8ff0a3548219165e8864d4a59a951206298 --- /dev/null +++ b/video_llama/configs/datasets/cc_sbu/align.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu_align: + data_type: images + build_info: + storage: /path/to/cc_sbu_align_dataset diff --git a/video_llama/configs/datasets/cc_sbu/defaults.yaml b/video_llama/configs/datasets/cc_sbu/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..359de601d2511acd6cac34e7ee1c20a1dfe9ae04 --- /dev/null +++ b/video_llama/configs/datasets/cc_sbu/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu: + data_type: images + build_info: + storage: /path/to/cc_sbu_dataset/{00000..00001}.tar diff --git a/video_llama/configs/datasets/instruct/llava_instruct.yaml b/video_llama/configs/datasets/instruct/llava_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0ec4a938e299f98f6d84104c909da210385003af --- /dev/null +++ b/video_llama/configs/datasets/instruct/llava_instruct.yaml @@ -0,0 +1,6 @@ +datasets: + llava_instruct: + data_type: image + build_info: + anno_dir: /path/llava_instruct_150k.json + videos_dir: /path/train2014/train2014/ diff --git a/video_llama/configs/datasets/instruct/webvid_instruct.yaml b/video_llama/configs/datasets/instruct/webvid_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9619106ad8cb1000c3c40fa48672fb9247988d74 --- /dev/null +++ b/video_llama/configs/datasets/instruct/webvid_instruct.yaml @@ -0,0 +1,6 @@ +datasets: + webvid_instruct: + data_type: image + build_info: + anno_dir: /path/webvid_align/videochat_instruct_11k.json + videos_dir: /path/webvid_align/videos/ diff --git a/video_llama/configs/datasets/laion/defaults.yaml b/video_llama/configs/datasets/laion/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7dfff3ba891d96a136510f34b1dcf1b774705f4a --- /dev/null +++ b/video_llama/configs/datasets/laion/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + laion: + data_type: images + build_info: + storage: path/laion/laion_dataset/{00000..00001}.tar diff --git a/video_llama/configs/datasets/webvid/defaults.yaml b/video_llama/configs/datasets/webvid/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..046ae32dde61e2d79d1f519e1c8c653d8e2b5886 --- /dev/null +++ b/video_llama/configs/datasets/webvid/defaults.yaml @@ -0,0 +1,6 @@ +datasets: + webvid: + data_type: video + build_info: + anno_dir: path/webvid/webvid_tain_data/annotations/ + videos_dir: path//webvid/webvid_tain_data/videos/ diff --git a/video_llama/configs/default.yaml b/video_llama/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ff5a6a23fa2e3914938631b96c71fdf723dbbc10 --- /dev/null +++ b/video_llama/configs/default.yaml @@ -0,0 +1,5 @@ +env: + # For default users + # cache_root: "cache" + # For internal use with persistent storage + cache_root: "/export/home/.cache/minigpt4" diff --git a/video_llama/configs/models/minigpt4.yaml b/video_llama/configs/models/minigpt4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..358c3f5f7b53251c607ea490ee262a890e531dc6 --- /dev/null +++ b/video_llama/configs/models/minigpt4.yaml @@ -0,0 +1,33 @@ +model: + arch: mini_gpt4 + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + freeze_qformer: True + + # Q-Former + num_query_token: 32 + + # Vicuna + llama_model: "ckpt/vicuna-13b/" + + # generation configs + prompt: "" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/video_llama/configs/models/video_llama.yaml b/video_llama/configs/models/video_llama.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27ce07c3fbaa5a867279572c5b7d1dc31d469791 --- /dev/null +++ b/video_llama/configs/models/video_llama.yaml @@ -0,0 +1,36 @@ +model: + arch: video_llama + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + freeze_qformer: True + + # Q-Former + num_query_token: 32 + + # Vicuna + llama_model: "ckpt/vicuna-7b/" + + # generation configs + prompt: "" + +preprocess: + vis_processor: + train: + name: "alpro_video_train" + image_size: 224 + n_frms: 8 + eval: + name: "alpro_video_eval" + image_size: 224 + n_frms: 8 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + \ No newline at end of file diff --git a/video_llama/conversation/__init__.py b/video_llama/conversation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video_llama/conversation/conversation_video.py b/video_llama/conversation/conversation_video.py new file mode 100644 index 0000000000000000000000000000000000000000..cd96a7a275f691519cd86200d7ed178d7cd2b75f --- /dev/null +++ b/video_llama/conversation/conversation_video.py @@ -0,0 +1,248 @@ +""" +Conversation prompt template of Video-LLaMA. +Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py +""" +import argparse +import time +from PIL import Image + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer +from transformers import StoppingCriteria, StoppingCriteriaList + +import dataclasses +from enum import auto, Enum +from typing import List, Tuple, Any +import os +from video_llama.common.registry import registry +from video_llama.processors.video_processor import ToTHWC,ToUint8,load_video +from video_llama.processors import Blip2ImageEvalProcessor +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + # system_img: List[Image.Image] = [] + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + + skip_next: bool = False + conv_id: Any = None + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + # system_img=self.system_img, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + conv_id=self.conv_id) + + def dict(self): + return { + "system": self.system, + # "system_img": self.system_img, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + "conv_id": self.conv_id, + } + + +class StoppingCriteriaSub(StoppingCriteria): + + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop):])).item(): + return True + + return False + + +CONV_VISION = Conversation( + system="Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions.", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +default_conversation = Conversation( + system="", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +class Chat: + def __init__(self, model, vis_processor, device='cuda:0'): + self.device = device + self.model = model + self.vis_processor = vis_processor + self.image_vis_processor = Blip2ImageEvalProcessor() + stop_words_ids = [torch.tensor([835]).to(self.device), + torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + + def ask(self, text, conv): + if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ + and ('' in conv.messages[-1][1] or '' in conv.messages[-1][1]): # last message is image. + conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) + else: + conv.append_message(conv.roles[0], text) + + def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + conv.append_message(conv.roles[1], None) + embs = self.get_context_emb(conv, img_list) + + current_max_len = embs.shape[1] + max_new_tokens + if current_max_len - max_length > 0: + print('Warning: The number of tokens in current conversation exceeds the max length. ' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, current_max_len - max_length) + + embs = embs[:, begin_idx:] + + outputs = self.model.llama_model.generate( + inputs_embeds=embs, + max_new_tokens=max_new_tokens, + stopping_criteria=self.stopping_criteria, + num_beams=num_beams, + do_sample=True, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + temperature=temperature, + ) + output_token = outputs[0] + if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it + output_token = output_token[1:] + if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) + output_text = output_text.split('###')[0] # remove the stop sign '###' + output_text = output_text.split('Assistant:')[-1].strip() + conv.messages[-1][1] = output_text + return output_text, output_token.cpu().numpy() + + def upload_video(self, video, conv, img_list): + + msg = "" + if isinstance(video, str): # is a video path + ext = os.path.splitext(video)[-1].lower() + print(video) + # image = self.vis_processor(image).unsqueeze(0).to(self.device) + video, msg = load_video( + video_path=video, + n_frms=8, + height=224, + width=224, + sampling ="uniform", return_msg = True + ) + video = self.vis_processor.transform(video) + video = video.unsqueeze(0).to(self.device) + # print(image) + else: + raise NotImplementedError + + image_emb, _ = self.model.encode_img(video) + img_list.append(image_emb) + conv.append_message(conv.roles[0], " "+ msg) + return "Received." + + def upload_img(self, image, conv, img_list): + + msg = "" + if isinstance(image, str): # is a image path + raw_image = Image.open(image).convert('RGB') # 增加一个时间维度 + image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device) + elif isinstance(image, Image.Image): + raw_image = image + image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device) + elif isinstance(image, torch.Tensor): + if len(image.shape) == 3: + image = image.unsqueeze(0) + image = image.to(self.device) + else: + raise NotImplementedError + + image_emb, _ = self.model.encode_img(image) + img_list.append(image_emb) + # Todo msg="" + conv.append_message(conv.roles[0], " "+ msg) + + return "Received." + + def get_context_emb(self, conv, img_list): + prompt = conv.get_prompt() + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.model.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + diff --git a/video_llama/datasets/__init__.py b/video_llama/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video_llama/datasets/builders/__init__.py b/video_llama/datasets/builders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b160d0b8ad5793e368d8b2d26ff9829fa3ddd9a --- /dev/null +++ b/video_llama/datasets/builders/__init__.py @@ -0,0 +1,77 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from video_llama.datasets.builders.base_dataset_builder import load_dataset_config +from video_llama.datasets.builders.image_text_pair_builder import ( + CCSBUBuilder, + LaionBuilder, + CCSBUAlignBuilder +) +from video_llama.datasets.builders.video_caption_builder import WebvidBuilder +from video_llama.common.registry import registry +from video_llama.datasets.builders.instruct_builder import WebvidInstruct_Builder,LlavaInstruct_Builder +__all__ = [ + "CCSBUBuilder", + "LaionBuilder", + "CCSBUAlignBuilder", + "WebvidBuilder", + "LlavaInstruct_Builder", + "WebvidInstruct_Builder" + +] + + +def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): + """ + Example + + >>> dataset = load_dataset("coco_caption", cfg=None) + >>> splits = dataset.keys() + >>> print([len(dataset[split]) for split in splits]) + + """ + if cfg_path is None: + cfg = None + else: + cfg = load_dataset_config(cfg_path) + + try: + builder = registry.get_builder_class(name)(cfg) + except TypeError: + print( + f"Dataset {name} not found. Available datasets:\n" + + ", ".join([str(k) for k in dataset_zoo.get_names()]) + ) + exit(1) + + if vis_path is not None: + if data_type is None: + # use default data type in the config + data_type = builder.config.data_type + + assert ( + data_type in builder.config.build_info + ), f"Invalid data_type {data_type} for {name}." + + builder.config.build_info.get(data_type).storage = vis_path + + dataset = builder.build_datasets() + return dataset + + +class DatasetZoo: + def __init__(self) -> None: + self.dataset_zoo = { + k: list(v.DATASET_CONFIG_DICT.keys()) + for k, v in sorted(registry.mapping["builder_name_mapping"].items()) + } + + def get_names(self): + return list(self.dataset_zoo.keys()) + + +dataset_zoo = DatasetZoo() diff --git a/video_llama/datasets/builders/base_dataset_builder.py b/video_llama/datasets/builders/base_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..86c2cf688e9bcac67138aa32c58927e4a5ddebba --- /dev/null +++ b/video_llama/datasets/builders/base_dataset_builder.py @@ -0,0 +1,236 @@ +""" + This file is from + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os +import shutil +import warnings + +from omegaconf import OmegaConf +import torch.distributed as dist +from torchvision.datasets.utils import download_url + +import video_llama.common.utils as utils +from video_llama.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from video_llama.common.registry import registry +from video_llama.processors.base_processor import BaseProcessor + + + +class BaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. + self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type + + self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + vis_proc_cfg = self.config.get("vis_processor") + txt_proc_cfg = self.config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + + self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) + self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + + self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) + self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + + @classmethod + def default_config_path(cls, type="default"): + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + self._download_ann() + self._download_vis() + + def _download_ann(self): + """ + Download annotation files if necessary. + All the vision-language datasets should have annotations of unified format. + + storage_path can be: + (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. + (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. + + Local annotation paths should be relative. + """ + anns = self.config.build_info.annotations + + splits = anns.keys() + + cache_root = registry.get_path("cache_root") + + for split in splits: + info = anns[split] + + urls, storage_paths = info.get("url", None), info.storage + + if isinstance(urls, str): + urls = [urls] + if isinstance(storage_paths, str): + storage_paths = [storage_paths] + + assert len(urls) == len(storage_paths) + + for url_or_filename, storage_path in zip(urls, storage_paths): + # if storage_path is relative, make it full by prefixing with cache_root. + if not os.path.isabs(storage_path): + storage_path = os.path.join(cache_root, storage_path) + + dirname = os.path.dirname(storage_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + if os.path.isfile(url_or_filename): + src, dst = url_or_filename, storage_path + if not os.path.exists(dst): + shutil.copyfile(src=src, dst=dst) + else: + logging.info("Using existing file {}.".format(dst)) + else: + if os.path.isdir(storage_path): + # if only dirname is provided, suffix with basename of URL. + raise ValueError( + "Expecting storage_path to be a file path, got directory {}".format( + storage_path + ) + ) + else: + filename = os.path.basename(storage_path) + + download_url(url=url_or_filename, root=dirname, filename=filename) + + def _download_vis(self): + + storage_path = self.config.build_info.get(self.data_type).storage + storage_path = utils.get_cache_path(storage_path) + + if not os.path.exists(storage_path): + warnings.warn( + f""" + The specified path {storage_path} for visual inputs does not exist. + Please provide a correct path to the visual inputs or + refer to datasets/download_scripts/README.md for downloading instructions. + """ + ) + + def build(self): + """ + Create by split datasets inheriting torch.utils.data.Datasets. + + # build() can be dataset-specific. Overwrite to customize. + """ + self.build_processors() + + build_info = self.config.build_info + + ann_info = build_info.annotations + vis_info = build_info.get(self.data_type) + + datasets = dict() + for split in ann_info.keys(): + if split not in ["train", "val", "test"]: + continue + + is_train = split == "train" + + # processors + vis_processor = ( + self.vis_processors["train"] + if is_train + else self.vis_processors["eval"] + ) + text_processor = ( + self.text_processors["train"] + if is_train + else self.text_processors["eval"] + ) + + # annotation path + ann_paths = ann_info.get(split).storage + if isinstance(ann_paths, str): + ann_paths = [ann_paths] + + abs_ann_paths = [] + for ann_path in ann_paths: + if not os.path.isabs(ann_path): + ann_path = utils.get_cache_path(ann_path) + abs_ann_paths.append(ann_path) + ann_paths = abs_ann_paths + + # visual data storage path + vis_path = os.path.join(vis_info.storage, split) + + if not os.path.isabs(vis_path): + # vis_path = os.path.join(utils.get_cache_path(), vis_path) + vis_path = utils.get_cache_path(vis_path) + + if not os.path.exists(vis_path): + warnings.warn("storage path {} does not exist.".format(vis_path)) + + # create datasets + dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls + datasets[split] = dataset_cls( + vis_processor=vis_processor, + text_processor=text_processor, + ann_paths=ann_paths, + vis_root=vis_path, + ) + + return datasets + + +def load_dataset_config(cfg_path): + cfg = OmegaConf.load(cfg_path).datasets + cfg = cfg[list(cfg.keys())[0]] + + return cfg diff --git a/video_llama/datasets/builders/image_text_pair_builder.py b/video_llama/datasets/builders/image_text_pair_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8f93bf8f0dd51318c01940f07dc10e9dda2dd275 --- /dev/null +++ b/video_llama/datasets/builders/image_text_pair_builder.py @@ -0,0 +1,106 @@ +import os +import logging +import warnings + +from video_llama.common.registry import registry +from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from video_llama.datasets.datasets.laion_dataset import LaionDataset +from video_llama.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset + + +@registry.register_builder("cc_sbu") +class CCSBUBuilder(BaseDatasetBuilder): + train_dataset_cls = CCSBUDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("laion") +class LaionBuilder(BaseDatasetBuilder): + train_dataset_cls = LaionDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("cc_sbu_align") +class CCSBUAlignBuilder(BaseDatasetBuilder): + train_dataset_cls = CCSBUAlignDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/cc_sbu/align.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + storage_path = build_info.storage + + datasets = dict() + + if not os.path.exists(storage_path): + warnings.warn("storage path {} does not exist.".format(storage_path)) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_paths=[os.path.join(storage_path, 'filter_cap.json')], + vis_root=os.path.join(storage_path, 'image'), + ) + + return datasets + diff --git a/video_llama/datasets/builders/instruct_builder.py b/video_llama/datasets/builders/instruct_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..b95238785386af934721d65cc7859d60f57023ae --- /dev/null +++ b/video_llama/datasets/builders/instruct_builder.py @@ -0,0 +1,78 @@ +import os +import logging +import warnings + +from video_llama.common.registry import registry +from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from video_llama.datasets.datasets.laion_dataset import LaionDataset +from video_llama.datasets.datasets.llava_instruct_dataset import Instruct_Dataset +from video_llama.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset + +@registry.register_builder("instruct") +class Instruct_Builder(BaseDatasetBuilder): + train_dataset_cls = Instruct_Dataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + datasets = dict() + split = "train" + + build_info = self.config.build_info + dataset_cls = self.train_dataset_cls + if self.config.num_video_query_token: + num_video_query_token = self.config.num_video_query_token + else: + num_video_query_token = 32 + + if self.config.tokenizer_name: + tokenizer_name = self.config.tokenizer_name + else: + tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/' + + + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + vis_root=build_info.videos_dir, + ann_root=build_info.anno_dir, + num_video_query_token = num_video_query_token, + tokenizer_name = tokenizer_name, + data_type = self.config.data_type + ) + + return datasets + +@registry.register_builder("webvid_instruct") +class WebvidInstruct_Builder(Instruct_Builder): + train_dataset_cls = Video_Instruct_Dataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/instruct/webvid_instruct.yaml", + } + +@registry.register_builder("webvid_instruct_zh") +class WebvidInstruct_zh_Builder(Instruct_Builder): + train_dataset_cls = Video_Instruct_Dataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/instruct/webvid_instruct.yaml", + } + + + +@registry.register_builder("llava_instruct") +class LlavaInstruct_Builder(Instruct_Builder): + train_dataset_cls = Instruct_Dataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/instruct/llava_instruct.yaml", + } + diff --git a/video_llama/datasets/builders/video_caption_builder.py b/video_llama/datasets/builders/video_caption_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..e73ef9db75c4699e1a763b459a7c96c1147e192b --- /dev/null +++ b/video_llama/datasets/builders/video_caption_builder.py @@ -0,0 +1,34 @@ +import os +import logging +import warnings + +from video_llama.common.registry import registry +from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from video_llama.datasets.datasets.webvid_datasets import WebvidDataset + +@registry.register_builder("webvid") +class WebvidBuilder(BaseDatasetBuilder): + train_dataset_cls = WebvidDataset + DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + datasets = dict() + split = "train" + + build_info = self.config.build_info + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + vis_root=build_info.videos_dir, + ann_root=build_info.anno_dir + ) + + return datasets \ No newline at end of file diff --git a/video_llama/datasets/data_utils.py b/video_llama/datasets/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe6a567bae667f00ef0ee1d4d9075649107b471 --- /dev/null +++ b/video_llama/datasets/data_utils.py @@ -0,0 +1,196 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import gzip +import logging +import os +import random as rnd +import tarfile +import zipfile +import random +from typing import List +from tqdm import tqdm + +import decord +from decord import VideoReader +import webdataset as wds +import numpy as np +import torch +from torch.utils.data.dataset import IterableDataset + +from video_llama.common.registry import registry +from video_llama.datasets.datasets.base_dataset import ConcatDataset + + +decord.bridge.set_bridge("torch") +MAX_INT = registry.get("MAX_INT") + + +class ChainDataset(wds.DataPipeline): + r"""Dataset for chaining multiple :class:`DataPipeline` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + def __init__(self, datasets: List[wds.DataPipeline]) -> None: + super().__init__() + self.datasets = datasets + self.prob = [] + self.names = [] + for dataset in self.datasets: + if hasattr(dataset, 'name'): + self.names.append(dataset.name) + else: + self.names.append('Unknown') + if hasattr(dataset, 'sample_ratio'): + self.prob.append(dataset.sample_ratio) + else: + self.prob.append(1) + logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") + + def __iter__(self): + datastreams = [iter(dataset) for dataset in self.datasets] + while True: + select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] + yield next(select_datastream) + + +def apply_to_sample(f, sample): + if len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample): + def _move_to_cuda(tensor): + return tensor.cuda() + + return apply_to_sample(_move_to_cuda, sample) + + +def prepare_sample(samples, cuda_enabled=True): + if cuda_enabled: + samples = move_to_cuda(samples) + + # TODO fp16 support + + return samples + + +def reorg_datasets_by_split(datasets): + """ + Organizes datasets by split. + + Args: + datasets: dict of torch.utils.data.Dataset objects by name. + + Returns: + Dict of datasets by split {split_name: List[Datasets]}. + """ + # if len(datasets) == 1: + # return datasets[list(datasets.keys())[0]] + # else: + reorg_datasets = dict() + + # reorganize by split + for _, dataset in datasets.items(): + for split_name, dataset_split in dataset.items(): + if split_name not in reorg_datasets: + reorg_datasets[split_name] = [dataset_split] + else: + reorg_datasets[split_name].append(dataset_split) + + return reorg_datasets + + +def concat_datasets(datasets): + """ + Concatenates multiple datasets into a single dataset. + + It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support + generic IterableDataset because it requires creating separate samplers. + + Now only supports conctenating training datasets and assuming validation and testing + have only a single dataset. This is because metrics should not be computed on the concatenated + datasets. + + Args: + datasets: dict of torch.utils.data.Dataset objects by split. + + Returns: + Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, + "val" and "test" remain the same. + + If the input training datasets contain both map-style and DataPipeline datasets, returns + a tuple, where the first element is a concatenated map-style dataset and the second + element is a chained DataPipeline dataset. + + """ + # concatenate datasets in the same split + for split_name in datasets: + if split_name != "train": + assert ( + len(datasets[split_name]) == 1 + ), "Do not support multiple {} datasets.".format(split_name) + datasets[split_name] = datasets[split_name][0] + else: + iterable_datasets, map_datasets = [], [] + for dataset in datasets[split_name]: + if isinstance(dataset, wds.DataPipeline): + logging.info( + "Dataset {} is IterableDataset, can't be concatenated.".format( + dataset + ) + ) + iterable_datasets.append(dataset) + elif isinstance(dataset, IterableDataset): + raise NotImplementedError( + "Do not support concatenation of generic IterableDataset." + ) + else: + map_datasets.append(dataset) + + # if len(iterable_datasets) > 0: + # concatenate map-style datasets and iterable-style datasets separately + if len(iterable_datasets) > 1: + chained_datasets = ( + ChainDataset(iterable_datasets) + ) + elif len(iterable_datasets) == 1: + chained_datasets = iterable_datasets[0] + else: + chained_datasets = None + + concat_datasets = ( + ConcatDataset(map_datasets) if len(map_datasets) > 0 else None + ) + + train_datasets = concat_datasets, chained_datasets + train_datasets = tuple([x for x in train_datasets if x is not None]) + train_datasets = ( + train_datasets[0] if len(train_datasets) == 1 else train_datasets + ) + + datasets[split_name] = train_datasets + + return datasets + diff --git a/video_llama/datasets/datasets/__init__.py b/video_llama/datasets/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video_llama/datasets/datasets/base_dataset.py b/video_llama/datasets/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2a8d0e21370129c0182cddc427eb293bbe5982 --- /dev/null +++ b/video_llama/datasets/datasets/base_dataset.py @@ -0,0 +1,68 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import json +from typing import Iterable + +from torch.utils.data import Dataset, ConcatDataset +from torch.utils.data.dataloader import default_collate + + +class BaseDataset(Dataset): + def __init__( + self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] + ): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.annotation = [] + for ann_path in ann_paths: + self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __len__(self): + return len(self.annotation) + + def collater(self, samples): + return default_collate(samples) + + def set_processors(self, vis_processor, text_processor): + self.vis_processor = vis_processor + self.text_processor = text_processor + + def _add_instance_ids(self, key="instance_id"): + for idx, ann in enumerate(self.annotation): + ann[key] = str(idx) + + +class ConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__(datasets) + + def collater(self, samples): + # TODO For now only supports datasets with same underlying collater implementations + + all_keys = set() + for s in samples: + all_keys.update(s) + + shared_keys = all_keys + for s in samples: + shared_keys = shared_keys & set(s.keys()) + + samples_shared_keys = [] + for s in samples: + samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) + + return self.datasets[0].collater(samples_shared_keys) diff --git a/video_llama/datasets/datasets/caption_datasets.py b/video_llama/datasets/datasets/caption_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..a78105896012b87174a547a365451d5d67fd8e93 --- /dev/null +++ b/video_llama/datasets/datasets/caption_datasets.py @@ -0,0 +1,85 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from video_llama.datasets.datasets.base_dataset import BaseDataset +from PIL import Image + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "image": sample["image"], + } + ) + + +class CaptionDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + img_file = '{:0>12}.jpg'.format(ann["image_id"]) + image_path = os.path.join(self.vis_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + return { + "image": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class CaptionEvalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + return { + "image": image, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } diff --git a/video_llama/datasets/datasets/cc_sbu_dataset.py b/video_llama/datasets/datasets/cc_sbu_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..59311479552e55b0e6f7d9aec3d70b3d993f92d1 --- /dev/null +++ b/video_llama/datasets/datasets/cc_sbu_dataset.py @@ -0,0 +1,49 @@ +import os +from PIL import Image +import webdataset as wds +from video_llama.datasets.datasets.base_dataset import BaseDataset +from video_llama.datasets.datasets.caption_datasets import CaptionDataset + + +class CCSBUDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "image": sample[0], + "text_input": self.text_processor(sample[1]["caption"]), + "type":'image', + } + + +class CCSBUAlignDataset(CaptionDataset): + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + img_file = '{}.jpg'.format(ann["image_id"]) + image_path = os.path.join(self.vis_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = ann["caption"] + + return { + "image": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + "type":'image', + } \ No newline at end of file diff --git a/video_llama/datasets/datasets/dataloader_utils.py b/video_llama/datasets/datasets/dataloader_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2f574e24d2a32a18533a11492cfd481ff2cfbb --- /dev/null +++ b/video_llama/datasets/datasets/dataloader_utils.py @@ -0,0 +1,162 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import time +import random +import torch +from video_llama.datasets.data_utils import move_to_cuda +from torch.utils.data import DataLoader + + +class MultiIterLoader: + """ + A simple wrapper for iterating over multiple iterators. + + Args: + loaders (List[Loader]): List of Iterator loaders. + ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. + """ + + def __init__(self, loaders, ratios=None): + # assert all loaders has __next__ method + for loader in loaders: + assert hasattr( + loader, "__next__" + ), "Loader {} has no __next__ method.".format(loader) + + if ratios is None: + ratios = [1.0] * len(loaders) + else: + assert len(ratios) == len(loaders) + ratios = [float(ratio) / sum(ratios) for ratio in ratios] + + self.loaders = loaders + self.ratios = ratios + + def __next__(self): + # random sample from each loader by ratio + loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] + return next(self.loaders[loader_idx]) + + +class PrefetchLoader(object): + """ + Modified from https://github.com/ChenRocks/UNITER. + + overlap compute and cuda data transfer + (copied and then modified from nvidia apex) + """ + + def __init__(self, loader): + self.loader = loader + self.stream = torch.cuda.Stream() + + def __iter__(self): + loader_it = iter(self.loader) + self.preload(loader_it) + batch = self.next(loader_it) + while batch is not None: + is_tuple = isinstance(batch, tuple) + if is_tuple: + task, batch = batch + + if is_tuple: + yield task, batch + else: + yield batch + batch = self.next(loader_it) + + def __len__(self): + return len(self.loader) + + def preload(self, it): + try: + self.batch = next(it) + except StopIteration: + self.batch = None + return + # if record_stream() doesn't work, another option is to make sure + # device inputs are created on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, + # device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, + # device='cuda') + # Need to make sure the memory allocated for next_* is not still in use + # by the main stream at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.batch = move_to_cuda(self.batch) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this + # side stream. + # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + def next(self, it): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is not None: + record_cuda_stream(batch) + self.preload(it) + return batch + + def __getattr__(self, name): + method = self.loader.__getattribute__(name) + return method + + +def record_cuda_stream(batch): + if isinstance(batch, torch.Tensor): + batch.record_stream(torch.cuda.current_stream()) + elif isinstance(batch, list) or isinstance(batch, tuple): + for t in batch: + record_cuda_stream(t) + elif isinstance(batch, dict): + for t in batch.values(): + record_cuda_stream(t) + else: + pass + + +class IterLoader: + """ + A wrapper to convert DataLoader as an infinite iterator. + + Modified from: + https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py + """ + + def __init__(self, dataloader: DataLoader, use_distributed: bool = False): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._use_distributed = use_distributed + self._epoch = 0 + + @property + def epoch(self) -> int: + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __iter__(self): + return self + + def __len__(self): + return len(self._dataloader) diff --git a/video_llama/datasets/datasets/laion_dataset.py b/video_llama/datasets/datasets/laion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1be30abb188e1afad6fe678ccbb367931a2b3d26 --- /dev/null +++ b/video_llama/datasets/datasets/laion_dataset.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import webdataset as wds +from video_llama.datasets.datasets.base_dataset import BaseDataset + + +class LaionDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "image": sample[0], + "text_input": self.text_processor(sample[1]["caption"]), + } + diff --git a/video_llama/datasets/datasets/llava_instruct_dataset.py b/video_llama/datasets/datasets/llava_instruct_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..105e0981581b7934c5df2bc53ecf03142cc4c969 --- /dev/null +++ b/video_llama/datasets/datasets/llava_instruct_dataset.py @@ -0,0 +1,228 @@ +import os +from video_llama.datasets.datasets.base_dataset import BaseDataset +from video_llama.datasets.datasets.caption_datasets import CaptionDataset +import pandas as pd +import decord +from decord import VideoReader +import random +import torch +from torch.utils.data.dataloader import default_collate +from PIL import Image +from typing import Dict, Optional, Sequence +import transformers +import pathlib +import json +from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer +from video_llama.conversation.conversation_video import Conversation,SeparatorStyle +DEFAULT_IMAGE_PATCH_TOKEN = '' +DEFAULT_IMAGE_TOKEN = "" +import copy +IGNORE_INDEX = -100 +image_conversation = Conversation( + system="", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) +IGNORE_INDEX = -100 + +class Instruct_Dataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'image'): + """ + vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/) + ann_root (string): Root directory of video (e.g. webvid_eval/annotations/) + split (string): val or test + """ + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + data_path = pathlib.Path(ann_root) + with data_path.open(encoding='utf-8') as f: + self.annotation = json.load(f) + + self.vis_root = vis_root + self.resize_size = 224 + self.num_frm = 8 + self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.num_video_query_token = num_video_query_token + self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN] + + self.transform = AlproVideoTrainProcessor( + image_size=self.resize_size, n_frms = self.num_frm + ).transform + self.data_type = data_type + + def _get_image_path(self, sample): + rel_video_fp ='COCO_train2014_' + sample['image'] + full_video_fp = os.path.join(self.vis_root, rel_video_fp) + return full_video_fp + + def __getitem__(self, index): + num_retries = 10 # skip error videos + for _ in range(num_retries): + try: + sample = self.annotation[index] + + image_path = self._get_image_path(sample) + conversation_list = sample['conversations'] + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + # text = self.text_processor(text) + sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token) + data_dict = preprocess( + sources, + self.tokenizer) + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + data_dict['image'] = image + except: + print(f"Failed to load examples with image: {image_path}. " + f"Will randomly sample an example as a replacement.") + index = random.randint(0, len(self) - 1) + continue + break + else: + raise RuntimeError(f"Failed to fetch image after {num_retries} retries.") + # "image_id" is kept to stay compatible with the COCO evaluation format + return { + "image": image, + "text_input": data_dict["input_ids"], + "labels": data_dict["labels"], + "type":'image', + } + + def __len__(self): + return len(self.annotation) + + def collater(self, instances): + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("text_input", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + batch['conv_type'] = 'multi' + return batch + + +def preprocess_multimodal( + conversation_list: Sequence[str], + multimodal_cfg: dict, + cur_token_len: int, +) -> Dict: + # 将conversational list中 + is_multimodal = True + # image_token_len = multimodal_cfg['image_token_len'] + image_token_len = cur_token_len + + for sentence in conversation_list: + replace_token = ''+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len+'/' + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return [conversation_list] + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "###" + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = image_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = image_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=512, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{image_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], + tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len diff --git a/video_llama/datasets/datasets/video_instruct_dataset.py b/video_llama/datasets/datasets/video_instruct_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7de6e20d30d9b0d7280d706636e9849b7f02618c --- /dev/null +++ b/video_llama/datasets/datasets/video_instruct_dataset.py @@ -0,0 +1,253 @@ +import os +from video_llama.datasets.datasets.base_dataset import BaseDataset +from video_llama.datasets.datasets.caption_datasets import CaptionDataset +import pandas as pd +import decord +from decord import VideoReader +import random +import torch +from torch.utils.data.dataloader import default_collate +from PIL import Image +from typing import Dict, Optional, Sequence +import transformers +import pathlib +import json +from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer +import copy +from video_llama.processors import transforms_video,AlproVideoTrainProcessor +from torchvision import transforms +from video_llama.processors.video_processor import ToTHWC,ToUint8,load_video +from video_llama.conversation.conversation_video import Conversation,SeparatorStyle + +DEFAULT_IMAGE_PATCH_TOKEN = '' +video_conversation = Conversation( + system="", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) +IGNORE_INDEX = -100 + +class Video_Instruct_Dataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'video'): + """ + vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/) + ann_root (string): Root directory of video (e.g. webvid_eval/annotations/) + split (string): val or test + """ + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + data_path = pathlib.Path(ann_root) + with data_path.open(encoding='utf-8') as f: + self.annotation = json.load(f) + + self.num_video_query_token = num_video_query_token + self.vis_root = vis_root + self.resize_size = 224 + self.num_frm = 8 + self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN] + + self.transform = AlproVideoTrainProcessor( + image_size=self.resize_size, n_frms = self.num_frm + ).transform + self.data_type = data_type + + def _get_video_path(self, sample): + rel_video_fp = sample['video'] + full_video_fp = os.path.join(self.vis_root, rel_video_fp) + return full_video_fp + + def __getitem__(self, index): + num_retries = 10 # skip error videos + for _ in range(num_retries): + try: + sample = self.annotation[index] + + video_path = self._get_video_path(sample) + conversation_list = sample['QA'] + + video, msg = load_video( + video_path=video_path, + n_frms=self.num_frm, + height=self.resize_size, + width=self.resize_size, + sampling ="uniform", return_msg = True + ) + video = self.transform(video) + if 'cn' in self.data_type: + msg = "" + # 添加视频,以及msg到convsation list 0 + sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token,msg = msg) + new_sources = convert_source_vicuna_format(sources) + + data_dict = preprocess( + new_sources, + self.tokenizer) + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + # image exist in the data + data_dict['image'] = video + except: + print(f"Failed to load examples with video: {video_path}. " + f"Will randomly sample an example as a replacement.") + index = random.randint(0, len(self) - 1) + continue + break + else: + raise RuntimeError(f"Failed to fetch video after {num_retries} retries.") + # "image_id" is kept to stay compatible with the COCO evaluation format + return { + "image": video, + "text_input": data_dict["input_ids"], + "labels": data_dict["labels"], + "type":'video', + } + + def __len__(self): + return len(self.annotation) + + def collater(self, instances): + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("text_input", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + batch['conv_type'] = 'multi' + return batch + +def convert_source_vicuna_format(sources): + new_sources = [] + for source in sources: + new_source = [] + for i, sentence in enumerate(source): + role_0_msg = sentence['q'] + role_1_msg = sentence['a'] + new_source.append({ + 'from':'human', + 'value': role_0_msg, + }) + new_source.append({ + 'from':'gpt', + 'value': role_1_msg, + }) + new_sources.append(new_source) + return new_sources + +def preprocess_multimodal( + conversation_list: Sequence[str], + multimodal_cfg: dict, + cur_token_len: int, + msg='' +) -> Dict: + # 将conversational list中 + is_multimodal = True + # image_token_len = multimodal_cfg['image_token_len'] + image_token_len = cur_token_len + conversation_list[0]["q"] = " " + msg + conversation_list[0]["q"] + return [conversation_list] + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "###" + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = video_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = video_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=512, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{video_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], + tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len diff --git a/video_llama/datasets/datasets/webvid_datasets.py b/video_llama/datasets/datasets/webvid_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf6b9d6dff0d96b04d40a40c0051527f7d01842 --- /dev/null +++ b/video_llama/datasets/datasets/webvid_datasets.py @@ -0,0 +1,122 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from video_llama.datasets.datasets.base_dataset import BaseDataset +from video_llama.datasets.datasets.caption_datasets import CaptionDataset +import pandas as pd +import decord +from decord import VideoReader +import random +import torch +from torch.utils.data.dataloader import default_collate +class WebvidDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_root): + """ + vis_root (string): Root directory of video (e.g. webvid_eval/video/) + ann_root (string): Root directory of video (e.g. webvid_eval/annotations/) + split (string): val or test + """ + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + + # 读取一个路径下所有的 + + ts_df = [] + for file_name in os.listdir(ann_root): + if file_name.endswith('.csv'): + df = pd.read_csv(os.path.join(ann_root, file_name)) + ts_df.append(df) + + merged_df = pd.concat(ts_df) + self.annotation = merged_df + self.vis_root = vis_root + self.resize_size = 224 + self.num_frm = 8 + self.frm_sampling_strategy = 'headtail' + + def _get_video_path(self, sample): + rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') + full_video_fp = os.path.join(self.vis_root, rel_video_fp) + return full_video_fp + + def __getitem__(self, index): + num_retries = 10 # skip error videos + for _ in range(num_retries): + sample = self.annotation.iloc[index] + sample_dict = sample.to_dict() + video_id = sample_dict['videoid'] + + if 'name' in sample_dict.keys(): + text = sample_dict['name'].strip() + else: + raise NotImplementedError("Un-supported text annotation format.") + + # fetch video + video_path = self._get_video_path(sample_dict) + # if os.path.exists(video_path): + try: + video = self.vis_processor(video_path) + except: + print(f"Failed to load examples with video: {video_path}. " + f"Will randomly sample an example as a replacement.") + index = random.randint(0, len(self) - 1) + continue + caption = self.text_processor(text) + + # print(video.size()) + if video is None or caption is None \ + or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]): + print(f"Failed to load examples with video: {video_path}. " + f"Will randomly sample an example as a replacement.") + index = random.randint(0, len(self) - 1) + continue + else: + break + else: + raise RuntimeError(f"Failed to fetch video after {num_retries} retries.") + # "image_id" is kept to stay compatible with the COCO evaluation format + return { + "image": video, + "text_input": caption, + "type":'video', + } + + def __len__(self): + return len(self.annotation) + + # def collater(self, samples): + # new_result = {} + # new_result['image'] = default_collate( [sample["image"] for sample in samples]) + # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples]) + # return new_result + +class WebvidDatasetEvalDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + vname = ann["video"] + video_path = os.path.join(self.vis_root, vname) + + video = self.vis_processor(video_path) + + return { + "video": video, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } + + diff --git a/video_llama/models/Qformer.py b/video_llama/models/Qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4902165ec6574d89f04cbeb2141b018278324ca6 --- /dev/null +++ b/video_llama/models/Qformer.py @@ -0,0 +1,1217 @@ +""" +Adapted from salesforce@LAVIS. Below is the original copyright: + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/video_llama/models/__init__.py b/video_llama/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ddc54288cdc9d7e75b71da3f8597052f4f4837c --- /dev/null +++ b/video_llama/models/__init__.py @@ -0,0 +1,201 @@ +""" +Adapted from salesforce@LAVIS Vision-CAIR@MiniGPT-4. Below is the original copyright: + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import torch +from omegaconf import OmegaConf + +from video_llama.common.registry import registry +from video_llama.models.base_model import BaseModel +from video_llama.models.blip2 import Blip2Base +from video_llama.models.video_llama import VideoLLAMA +from video_llama.processors.base_processor import BaseProcessor + + +__all__ = [ + "load_model", + "BaseModel", + "Blip2Base", + "VideoLLAMA" +] + + +def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): + """ + Load supported models. + + To list all available models and types in registry: + >>> from video_llama.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + checkpoint (str): path or to checkpoint. Default: None. + Note that expecting the checkpoint to have the same keys in state_dict as the model. + + Returns: + model (torch.nn.Module): model. + """ + + model = registry.get_model_class(name).from_pretrained(model_type=model_type) + + if checkpoint is not None: + model.load_checkpoint(checkpoint) + + if is_eval: + model.eval() + + if device == "cpu": + model = model.float() + + return model.to(device) + + +def load_preprocess(config): + """ + Load preprocessor configs and construct preprocessors. + + If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. + + Args: + config (dict): preprocessor configs. + + Returns: + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + + Key is "train" or "eval" for processors used in training and evaluation respectively. + """ + + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else BaseProcessor() + ) + + vis_processors = dict() + txt_processors = dict() + + vis_proc_cfg = config.get("vis_processor") + txt_proc_cfg = config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + else: + vis_train_cfg = None + vis_eval_cfg = None + + vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) + vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + else: + txt_train_cfg = None + txt_eval_cfg = None + + txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) + txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) + + return vis_processors, txt_processors + + +def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): + """ + Load model and its related preprocessors. + + List all available models and types in registry: + >>> from video_llama.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + """ + model_cls = registry.get_model_class(name) + + # load model + model = model_cls.from_pretrained(model_type=model_type) + + if is_eval: + model.eval() + + # load preprocess + cfg = OmegaConf.load(model_cls.default_config_path(model_type)) + if cfg is not None: + preprocess_cfg = cfg.preprocess + + vis_processors, txt_processors = load_preprocess(preprocess_cfg) + else: + vis_processors, txt_processors = None, None + logging.info( + f"""No default preprocess for model {name} ({model_type}). + This can happen if the model is not finetuned on downstream datasets, + or it is not intended for direct use without finetuning. + """ + ) + + if device == "cpu" or device == torch.device("cpu"): + model = model.float() + + return model.to(device), vis_processors, txt_processors + + +class ModelZoo: + """ + A utility class to create string representation of available model architectures and types. + + >>> from video_llama.models import model_zoo + >>> # list all available models + >>> print(model_zoo) + >>> # show total number of models + >>> print(len(model_zoo)) + """ + + def __init__(self) -> None: + self.model_zoo = { + k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) + for k, v in registry.mapping["model_name_mapping"].items() + } + + def __str__(self) -> str: + return ( + "=" * 50 + + "\n" + + f"{'Architectures':<30} {'Types'}\n" + + "=" * 50 + + "\n" + + "\n".join( + [ + f"{name:<30} {', '.join(types)}" + for name, types in self.model_zoo.items() + ] + ) + ) + + def __iter__(self): + return iter(self.model_zoo.items()) + + def __len__(self): + return sum([len(v) for v in self.model_zoo.values()]) + + +model_zoo = ModelZoo() diff --git a/video_llama/models/base_model.py b/video_llama/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..272ddd15129a83b6a5a0063553f512faca1f5612 --- /dev/null +++ b/video_llama/models/base_model.py @@ -0,0 +1,248 @@ +""" +Adapted from salesforce@LAVIS. Below is the original copyright: + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import numpy as np +import torch +import torch.nn as nn +from video_llama.common.dist_utils import download_cached_file, is_dist_avail_and_initialized +from video_llama.common.utils import get_abs_path, is_url +from omegaconf import OmegaConf + + +class BaseModel(nn.Module): + """Base class for models.""" + + def __init__(self): + super().__init__() + + @property + def device(self): + return list(self.parameters())[0].device + + def load_checkpoint(self, url_or_filename): + """ + Load from a finetuned checkpoint. + + This should expect no mismatch in the model keys and the checkpoint keys. + """ + + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint.keys(): + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + @classmethod + def from_pretrained(cls, model_type): + """ + Build a pretrained model from default configuration file, specified by model_type. + + Args: + - model_type (str): model type, specifying architecture and checkpoints. + + Returns: + - model (nn.Module): pretrained or finetuned model, depending on the configuration. + """ + model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model + model = cls.from_config(model_cfg) + + return model + + @classmethod + def default_config_path(cls, model_type): + assert ( + model_type in cls.PRETRAINED_MODEL_CONFIG_DICT + ), "Unknown model type {}".format(model_type) + return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) + + def load_checkpoint_from_config(self, cfg, **kwargs): + """ + Load checkpoint as specified in the config file. + + If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. + When loading the pretrained model, each task-specific architecture may define their + own load_from_pretrained() method. + """ + load_finetuned = cfg.get("load_finetuned", True) + if load_finetuned: + finetune_path = cfg.get("finetuned", None) + assert ( + finetune_path is not None + ), "Found load_finetuned is True, but finetune_path is None." + self.load_checkpoint(url_or_filename=finetune_path) + else: + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + assert "Found load_finetuned is False, but pretrain_path is None." + self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) + + def before_evaluation(self, **kwargs): + pass + + def show_n_params(self, return_str=True): + tot = 0 + for p in self.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return "{:.1f}M".format(tot / 1e6) + else: + return "{:.1f}K".format(tot / 1e3) + else: + return tot + + +class BaseEncoder(nn.Module): + """ + Base class for primitive encoders, such as ViT, TimeSformer, etc. + """ + + def __init__(self): + super().__init__() + + def forward_features(self, samples, **kwargs): + raise NotImplementedError + + @property + def device(self): + return list(self.parameters())[0].device + + +class SharedQueueMixin: + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr : ptr + batch_size] = image_feats.T + self.text_queue[:, ptr : ptr + batch_size] = text_feats.T + + if idxs is not None: + idxs = concat_all_gather(idxs) + self.idx_queue[:, ptr : ptr + batch_size] = idxs.T + + ptr = (ptr + batch_size) % self.queue_size # move pointer + self.queue_ptr[0] = ptr + + +class MomentumDistilationMixin: + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data = param_m.data * self.momentum + param.data * ( + 1.0 - self.momentum + ) + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [ + torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + + +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ) + return torch.index_select(x, dim, order_index.to(x.device)) diff --git a/video_llama/models/blip2.py b/video_llama/models/blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..e6a42b5ee1016acf4ebeb82d2c04fe69fe2364f5 --- /dev/null +++ b/video_llama/models/blip2.py @@ -0,0 +1,222 @@ +""" +Adapted from salesforce@LAVIS. Below is the original copyright: + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import contextlib +import logging +import os +import time +import datetime + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.nn.functional as F + +import video_llama.common.dist_utils as dist_utils +from video_llama.common.dist_utils import download_cached_file +from video_llama.common.utils import is_url +from video_llama.common.logger import MetricLogger +from video_llama.models.base_model import BaseModel +from video_llama.models.Qformer import BertConfig, BertLMHeadModel +from video_llama.models.eva_vit import create_eva_vit_g +from transformers import BertTokenizer + + +class Blip2Base(BaseModel): + @classmethod + def init_tokenizer(cls): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + return tokenizer + + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + @classmethod + def init_vision_encoder( + cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision + ): + assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + msg = self.load_state_dict(state_dict, strict=False) + + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +def compute_sim_matrix(model, data_loader, **kwargs): + k_test = kwargs.pop("k_test") + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=35, + return_tensors="pt", + ).to(model.device) + text_feat = model.forward_text(text_input) + text_embed = F.normalize(model.text_proj(text_feat)) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds, dim=0) + text_ids = torch.cat(text_ids, dim=0) + text_atts = torch.cat(text_atts, dim=0) + + vit_feats = [] + image_embeds = [] + for samples in data_loader: + image = samples["image"] + + image = image.to(model.device) + image_feat, vit_feat = model.forward_image(image) + image_embed = model.vision_proj(image_feat) + image_embed = F.normalize(image_embed, dim=-1) + + vit_feats.append(vit_feat.cpu()) + image_embeds.append(image_embed) + + vit_feats = torch.cat(vit_feats, dim=0) + image_embeds = torch.cat(image_embeds, dim=0) + + sims_matrix = [] + for image_embed in image_embeds: + sim_q2t = image_embed @ text_embeds.t() + sim_i2t, _ = sim_q2t.max(0) + sims_matrix.append(sim_i2t) + sims_matrix = torch.stack(sims_matrix, dim=0) + + score_matrix_i2t = torch.full( + (len(data_loader.dataset.image), len(texts)), -100.0 + ).to(model.device) + + num_tasks = dist_utils.get_world_size() + rank = dist_utils.get_rank() + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[topk_idx], + text_atts=text_atts[topk_idx], + ).float() + score_matrix_i2t[start + i, topk_idx] = score + topk_sim + + sims_matrix = sims_matrix.t() + score_matrix_t2i = torch.full( + (len(texts), len(data_loader.dataset.image)), -100.0 + ).to(model.device) + + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[topk_idx.cpu()].to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[start + i].repeat(k_test, 1), + text_atts=text_atts[start + i].repeat(k_test, 1), + ).float() + score_matrix_t2i[start + i, topk_idx] = score + topk_sim + + if dist_utils.is_dist_avail_and_initialized(): + dist.barrier() + torch.distributed.all_reduce( + score_matrix_i2t, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + score_matrix_t2i, op=torch.distributed.ReduceOp.SUM + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/video_llama/models/blip2_outputs.py b/video_llama/models/blip2_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..92d83a0556e6c5c3c0a603279f318605ae25d6d5 --- /dev/null +++ b/video_llama/models/blip2_outputs.py @@ -0,0 +1,111 @@ +""" +Adapted from salesforce@LAVIS. Below is the original copyright: + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_outputs import ( + ModelOutput, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) + + +@dataclass +class BlipSimilarity(ModelOutput): + sim_i2t: torch.FloatTensor = None + sim_t2i: torch.FloatTensor = None + + sim_i2t_m: Optional[torch.FloatTensor] = None + sim_t2i_m: Optional[torch.FloatTensor] = None + + sim_i2t_targets: Optional[torch.FloatTensor] = None + sim_t2i_targets: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipIntermediateOutput(ModelOutput): + """ + Data class for intermediate outputs of BLIP models. + + image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). + text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). + + image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). + text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). + + encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. + encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. + + decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. + decoder_labels (torch.LongTensor): labels for the captioning loss. + + itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). + itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) + + """ + + # uni-modal features + image_embeds: torch.FloatTensor = None + text_embeds: Optional[torch.FloatTensor] = None + + image_embeds_m: Optional[torch.FloatTensor] = None + text_embeds_m: Optional[torch.FloatTensor] = None + + # intermediate outputs of multimodal encoder + encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + + itm_logits: Optional[torch.FloatTensor] = None + itm_labels: Optional[torch.LongTensor] = None + + # intermediate outputs of multimodal decoder + decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None + decoder_labels: Optional[torch.LongTensor] = None + + +@dataclass +class BlipOutput(ModelOutput): + # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. + sims: Optional[BlipSimilarity] = None + + intermediate_output: BlipIntermediateOutput = None + + loss: Optional[torch.FloatTensor] = None + + loss_itc: Optional[torch.FloatTensor] = None + + loss_itm: Optional[torch.FloatTensor] = None + + loss_lm: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipOutputFeatures(ModelOutput): + """ + Data class of features from BlipFeatureExtractor. + + Args: + image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional + image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional + text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional + text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional + + The first embedding or feature is for the [CLS] token. + + Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. + """ + + image_embeds: Optional[torch.FloatTensor] = None + image_embeds_proj: Optional[torch.FloatTensor] = None + + text_embeds: Optional[torch.FloatTensor] = None + text_embeds_proj: Optional[torch.FloatTensor] = None + + multimodal_embeds: Optional[torch.FloatTensor] = None diff --git a/video_llama/models/eva_vit.py b/video_llama/models/eva_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..864bffd0c2ffad18c642ce55e9d0ccf44fbe5a56 --- /dev/null +++ b/video_llama/models/eva_vit.py @@ -0,0 +1,442 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + +from video_llama.common.dist_utils import download_cached_file + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + self.use_checkpoint = use_checkpoint + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) +# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) +# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None +# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) +# if isinstance(self.head, nn.Linear): +# trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() +# if isinstance(self.head, nn.Linear): +# self.head.weight.data.mul_(init_scale) +# self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + return x +# x = self.norm(x) + +# if self.fc_norm is not None: +# t = x[:, 1:, :] +# return self.fc_norm(t.mean(1)) +# else: +# return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) +# x = self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'].float() + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +# if isinstance(l, (nn.MultiheadAttention, Attention)): +# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: +# tensor = getattr(l, attr) +# if tensor is not None: +# tensor.data = tensor.data.half() + + model.apply(_convert_weights_to_fp16) + + +def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"): + model = VisionTransformer( + img_size=img_size, + patch_size=14, + use_mean_pooling=False, + embed_dim=1408, + depth=39, + num_heads=1408//88, + mlp_ratio=4.3637, + qkv_bias=True, + drop_path_rate=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + use_checkpoint=use_checkpoint, + ) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + cached_file = download_cached_file( + url, check_hash=False, progress=True + ) + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model,state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) +# print(incompatible_keys) + + if precision == "fp16": +# model.to("cuda") + convert_weights_to_fp16(model) + return model \ No newline at end of file diff --git a/video_llama/models/modeling_llama.py b/video_llama/models/modeling_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..12d980e189d902fb1a6d9ea05dc3ca91959b1c8c --- /dev/null +++ b/video_llama/models/modeling_llama.py @@ -0,0 +1,755 @@ +# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.models.llama.configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + batch_size, seq_length, _ = inputs_embeds.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + query_embeds = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "query_embeds": query_embeds, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + diff --git a/video_llama/models/video_llama.py b/video_llama/models/video_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..c287887992ab48fdb7306cba2f6703e6b081712c --- /dev/null +++ b/video_llama/models/video_llama.py @@ -0,0 +1,424 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from video_llama.common.registry import registry +from video_llama.models.blip2 import Blip2Base, disabled_train +from video_llama.models.modeling_llama import LlamaForCausalLM +# from video_llama.models.Qformer import BertEncoder +from transformers import LlamaTokenizer,BertConfig +# from transformers.models.bert.modeling_bert import BertEncoder +import einops +import copy +import os +from video_llama.models.Qformer import BertConfig, BertLMHeadModel +# from flamingo_pytorch import PerceiverResampler +@registry.register_model("video_llama") +class VideoLLAMA(Blip2Base): + """ + BLIP2 GPT-LLAMA model. + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna": "configs/models/video_llama.yaml", + } + + @classmethod + def init_video_Qformer(cls, num_query_token, vision_width,num_hidden_layers =2): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.num_hidden_layers = num_hidden_layers + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = 1 + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + + frozen_llama_proj=True, + llama_proj_model='', + fusion_header_type= "seqTransf", + max_frame_pos= 32, + fusion_head_layers = 2, + num_video_query_token = 32, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + self.low_resource = low_resource + + print('Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + self.ln_vision.train = disabled_train + logging.info("freeze vision encoder") + print('Loading VIT Done') + + print('Loading Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) + + if freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + self.query_tokens.requires_grad = False + logging.info("freeze Qformer") + logging.info('Loading Q-Former Done') + + logging.info('Loading LLAMA Tokenizer') + self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False, use_auth_token=os.environ["API_TOKEN"]) + if self.llama_tokenizer.pad_token is None: + self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + DEFAULT_IMAGE_PATCH_TOKEN = '' + self.llama_tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.IMAGE_PATCH_TOKEN_ID = self.llama_tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN] + + logging.info('Loading LLAMA Model') + if self.low_resource: + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + torch_dtype=torch.float16, + load_in_8bit=True, + device_map={'': device_8bit}, + use_auth_token=os.environ["API_TOKEN"] + ) + else: + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + torch_dtype=torch.float16,use_auth_token=os.environ["API_TOKEN"] + ) + + for name, param in self.llama_model.named_parameters(): + param.requires_grad = False + logging.info('Loading LLAMA Done') + + + logging.info('Loading LLAMA proj') + self.llama_proj = nn.Linear( + self.Qformer.config.hidden_size, self.llama_model.config.hidden_size + ) + if llama_proj_model: + print("load llama proj weight: {}".format(llama_proj_model)) + llama_proj_weight = torch.load(llama_proj_model, map_location="cpu") + msg = model.load_state_dict(llama_proj_weight['model'], strict=False) + + if frozen_llama_proj: + # todo frozen llama_proj + for name, param in self.llama_proj.named_parameters(): + param.requires_grad = False + logging.info('LLAMA proj is frozen') + else: + for name, param in self.llama_proj.named_parameters(): + param.requires_grad = True + logging.info('LLAMA proj is not frozen') + + logging.info('Loading llama_proj Done') + + self.max_txt_len = max_txt_len + self.end_sym = end_sym + + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] + self.prompt_list = [prompt_template.format(p) for p in filted_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + + self.video_frame_position_embedding = nn.Embedding(max_frame_pos, self.Qformer.config.hidden_size) + self.num_video_query_token = num_video_query_token + self.video_Qformer,self.video_query_tokens = self.init_video_Qformer(num_query_token = num_video_query_token,\ + vision_width=self.Qformer.config.hidden_size, num_hidden_layers =2) + + self.video_Qformer.cls = None + self.video_Qformer.bert.embeddings.word_embeddings = None + self.video_Qformer.bert.embeddings.position_embeddings = None + for layer in self.video_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + + def vit_to_cpu(self): + self.ln_vision.to("cpu") + self.ln_vision.float() + self.visual_encoder.to("cpu") + self.visual_encoder.float() + + def encode_img(self, image): + device = image.device + # if self.low_resource: + # self.vit_to_cpu() + # image = image.to("cpu") + + # input shape b,c,t,h,w + batch_size,_,time_length,_,_ = image.size() + image = einops.rearrange(image, 'b c t h w -> (b t) c h w') + with self.maybe_autocast(): + # embed image features with blip2, out: (b t) q h + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # add frame_pos embedding + position_ids = torch.arange(time_length, dtype=torch.long, device=query_tokens.device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + frame_position_embeddings = self.video_frame_position_embedding(position_ids) + q_hidden_state = query_output.last_hidden_state + + frame_position_embeddings = frame_position_embeddings.unsqueeze(-2) + frame_hidden_state = einops.rearrange(q_hidden_state, '(b t) q h -> b t q h',b=batch_size,t=time_length) + frame_hidden_state = frame_position_embeddings + frame_hidden_state + + # frame attention + frame_hidden_state = einops.rearrange(frame_hidden_state, 'b t q h -> b (t q) h',b=batch_size,t=time_length) + frame_atts = torch.ones(frame_hidden_state.size()[:-1], dtype=torch.long).to(device) + video_query_tokens = self.video_query_tokens.expand(frame_hidden_state.shape[0], -1, -1) + + # print('attention') + # print(video_query_tokens.size()) + # print(frame_hidden_state.size()) + video_query_output = self.video_Qformer.bert( + query_embeds=video_query_tokens, + encoder_hidden_states=frame_hidden_state, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + video_hidden = video_query_output.last_hidden_state + + inputs_llama = self.llama_proj(video_hidden) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image_embeds.device) + return inputs_llama, atts_llama + + def prompt_wrap(self, img_embeds, atts_img, prompt): + if prompt: + batch_size = img_embeds.shape[0] + # print(prompt) + p_before, p_after = prompt.split('') + p_before_tokens = self.llama_tokenizer( + p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_after_tokens = self.llama_tokenizer( + p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) + p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) + wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1) + wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1]) + + return wrapped_img_embeds, wrapped_atts_img + else: + return img_embeds, atts_img + + def forward(self, samples): + if 'conv_type' in samples.keys() and samples['conv_type']=='multi': + num_patch_tokens = self.num_video_query_token + im_patch_token_id = self.IMAGE_PATCH_TOKEN_ID + image = samples["images"] + input_ids = samples['input_ids'] + if len(image.size())==4: + time = 1 + image = einops.repeat(image, 'b c h w -> b c t h w',t = time) + img_embeds, atts_img = self.encode_img(image) + + temp_input_ids = copy.deepcopy(input_ids) + temp_input_ids[temp_input_ids == im_patch_token_id] = 0 + temp_input_embedding = self.llama_model.model.embed_tokens(temp_input_ids) + + new_input_embeds=[] + cur_image_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, temp_input_embedding): + cur_image_features = img_embeds[cur_image_idx] + + if (cur_input_ids == im_patch_token_id).sum() != num_patch_tokens: + raise ValueError("The number of image patch tokens should be the same as the number of image patches.") + masked_indices = torch.where(cur_input_ids == im_patch_token_id)[0] + mask_index_start = masked_indices[0] + if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patch_tokens, device=masked_indices.device, dtype=masked_indices.dtype)).any(): + raise ValueError("The image patch tokens should be consecutive.") + + cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patch_tokens:]), dim=0) + new_input_embeds.append(cur_new_input_embeds) + + cur_image_idx+=1 + inputs_embeds = torch.stack(new_input_embeds, dim=0) + targets = samples['labels'] + attention_mask = samples['attention_mask'] + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + return {"loss": loss} + else: + image = samples["image"] + + if len(image.size()) != 5: + time = 1 + image = einops.repeat(image, 'b c h w -> b c t h w',t = time) + + img_embeds, atts_img = self.encode_img(image) + + if self.prompt_list: + prompt = random.choice(self.prompt_list) + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt) + + + self.llama_tokenizer.padding_side = "right" + + text = [t + self.end_sym for t in samples["text_input"]] + + to_regress_tokens = self.llama_tokenizer( + text, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(image.device) + + targets = to_regress_tokens.input_ids.masked_fill( + to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + empty_targets = ( + torch.ones([atts_img.shape[0], atts_img.shape[1]+1], + dtype=torch.long).to(image.device).fill_(-100) # plus one for bos + ) + targets = torch.cat([empty_targets, targets], dim=1) + + batch_size = img_embeds.shape[0] + bos = torch.ones([batch_size, 1], + dtype=to_regress_tokens.input_ids.dtype, + device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + atts_bos = atts_img[:, :1] + + to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) + inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1) + attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) + + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + return {"loss": loss} + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 32) + end_sym = cfg.get("end_sym", '\n') + + frozen_llama_proj = cfg.get("frozen_llama_proj", True) + llama_proj_model = cfg.get("llama_proj_model", '') + + fusion_header_type = cfg.get("fusion_header_type", 'seqTransf') + max_frame_pos = cfg.get("max_frame_pos", 32) + fusion_head_layers = cfg.get("fusion_head_layers", 2) + num_video_query_token = cfg.get("num_video_query_token", 32) + + model = cls( + vit_model=vit_model, + q_former_model=q_former_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + freeze_qformer=freeze_qformer, + num_query_token=num_query_token, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + fusion_header_type=fusion_header_type, + max_frame_pos=max_frame_pos, + fusion_head_layers=fusion_head_layers, + frozen_llama_proj=frozen_llama_proj, + num_video_query_token=num_video_query_token + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + return model diff --git a/video_llama/processors/__init__.py b/video_llama/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..169237f3dd45dba53cf77f40c8a69e835d0bcecc --- /dev/null +++ b/video_llama/processors/__init__.py @@ -0,0 +1,38 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from video_llama.processors.base_processor import BaseProcessor +from video_llama.processors.blip_processors import ( + Blip2ImageTrainProcessor, + Blip2ImageEvalProcessor, + BlipCaptionProcessor, +) +from video_llama.processors.video_processor import ( + AlproVideoTrainProcessor, + AlproVideoEvalProcessor +) +from video_llama.common.registry import registry + +__all__ = [ + "BaseProcessor", + "Blip2ImageTrainProcessor", + "Blip2ImageEvalProcessor", + "BlipCaptionProcessor", + "AlproVideoTrainProcessor", + "AlproVideoEvalProcessor", +] + + +def load_processor(name, cfg=None): + """ + Example + + >>> processor = load_processor("alpro_video_train", cfg=None) + """ + processor = registry.get_processor_class(name).from_config(cfg) + + return processor diff --git a/video_llama/processors/base_processor.py b/video_llama/processors/base_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..39b33cdf8fcd97cfd3e4a5fbece6593357af9d41 --- /dev/null +++ b/video_llama/processors/base_processor.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from omegaconf import OmegaConf + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) diff --git a/video_llama/processors/blip_processors.py b/video_llama/processors/blip_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..6e603b638607921440bf7c4fcf22c5f1aeb7f20d --- /dev/null +++ b/video_llama/processors/blip_processors.py @@ -0,0 +1,142 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re + +from video_llama.common.registry import registry +from video_llama.processors.base_processor import BaseProcessor +from video_llama.processors.randaugment import RandomAugment +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + +@registry.register_processor("blip_caption") +class BlipCaptionProcessor(BaseProcessor): + def __init__(self, prompt="", max_words=50): + self.prompt = prompt + self.max_words = max_words + + def __call__(self, caption): + caption = self.prompt + self.pre_caption(caption) + + return caption + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 50) + + return cls(prompt=prompt, max_words=max_words) + + def pre_caption(self, caption): + caption = re.sub( + r"([.!\"()*#:;~])", + " ", + caption.lower(), + ) + caption = re.sub( + r"\s{2,}", + " ", + caption, + ) + caption = caption.rstrip("\n") + caption = caption.strip(" ") + + # truncate caption + caption_words = caption.split(" ") + if len(caption_words) > self.max_words: + caption = " ".join(caption_words[: self.max_words]) + + return caption + + +@registry.register_processor("blip2_image_train") +class Blip2ImageTrainProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + image_size, + scale=(min_scale, max_scale), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +@registry.register_processor("blip2_image_eval") +class Blip2ImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) + diff --git a/video_llama/processors/functional_video.py b/video_llama/processors/functional_video.py new file mode 100644 index 0000000000000000000000000000000000000000..597a29315d4e1a575e7209edb0618eeaf4fc024a --- /dev/null +++ b/video_llama/processors/functional_video.py @@ -0,0 +1,121 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import warnings + +import torch + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError( + f"target size should be tuple (height, width), instead got {target_size}" + ) + return torch.nn.functional.interpolate( + clip, size=target_size, mode=interpolation_mode, align_corners=False + ) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) + Return: + clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError( + "clip tensor should have data type uint8. Got %s" % str(clip.dtype) + ) + return clip.float().permute(3, 0, 1, 2) / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) + Returns: + flipped clip (torch.tensor): Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) diff --git a/video_llama/processors/randaugment.py b/video_llama/processors/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..7034a49ad5fc63b97910790017432617ff4c6d7b --- /dev/null +++ b/video_llama/processors/randaugment.py @@ -0,0 +1,398 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import cv2 +import numpy as np + +import torch + + +## aug functions +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + """ + same output as PIL.ImageOps.autocontrast + """ + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + """ + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + """ + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + """ + like PIL, rotate by degree, not radians + """ + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + """ + same output as PIL.ImageOps.posterize + """ + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + """ + same output as PIL.ImageEnhance.Color + """ + ## implementation according to PIL definition, quite slow + # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] + # out = blend(degenerate, img, factor) + # M = ( + # np.eye(3) * factor + # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) + # )[np.newaxis, np.newaxis, :] + M = np.float32( + [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] + ) * factor + np.float32([[0.114], [0.587], [0.299]]) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = ( + np.array([(el - mean) * factor + mean for el in range(256)]) + .clip(0, 255) + .astype(np.uint8) + ) + out = table[img] + return out + + +def brightness_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + """ + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + """ + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def posterize_func(img, bits): + """ + same output as PIL.ImageOps.posterize + """ + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +### level to args +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1,) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level,) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level,) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + "Identity": identity_func, + "AutoContrast": autocontrast_func, + "Equalize": equalize_func, + "Rotate": rotate_func, + "Solarize": solarize_func, + "Color": color_func, + "Contrast": contrast_func, + "Brightness": brightness_func, + "Sharpness": sharpness_func, + "ShearX": shear_x_func, + "TranslateX": translate_x_func, + "TranslateY": translate_y_func, + "Posterize": posterize_func, + "ShearY": shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + "Identity": none_level_to_args, + "AutoContrast": none_level_to_args, + "Equalize": none_level_to_args, + "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), + "Solarize": solarize_level_to_args(MAX_LEVEL), + "Color": enhance_level_to_args(MAX_LEVEL), + "Contrast": enhance_level_to_args(MAX_LEVEL), + "Brightness": enhance_level_to_args(MAX_LEVEL), + "Sharpness": enhance_level_to_args(MAX_LEVEL), + "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), + "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "Posterize": posterize_level_to_args(MAX_LEVEL), + "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +class VideoRandomAugment(object): + def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): + self.N = N + self.M = M + self.p = p + self.tensor_in_tensor_out = tensor_in_tensor_out + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N, replace=False) + return [(op, self.M) for op in sampled_ops] + + def __call__(self, frames): + assert ( + frames.shape[-1] == 3 + ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." + + if self.tensor_in_tensor_out: + frames = frames.numpy().astype(np.uint8) + + num_frames = frames.shape[0] + + ops = num_frames * [self.get_random_ops()] + apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] + + frames = torch.stack( + list(map(self._aug, frames, ops, apply_or_not)), dim=0 + ).float() + + return frames + + def _aug(self, img, ops, apply_or_not): + for i, (name, level) in enumerate(ops): + if not apply_or_not[i]: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return torch.from_numpy(img) + + +if __name__ == "__main__": + a = RandomAugment() + img = np.random.randn(32, 32, 3) + a(img) diff --git a/video_llama/processors/transforms_video.py b/video_llama/processors/transforms_video.py new file mode 100644 index 0000000000000000000000000000000000000000..1106f388c4091f919e0e9602fcb1363e9caec9a6 --- /dev/null +++ b/video_llama/processors/transforms_video.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + + +import numbers +import random + +from torchvision.transforms import ( + RandomCrop, + RandomResizedCrop, +) + +import video_llama.processors.functional_video as F + + +__all__ = [ + "RandomCropVideo", + "RandomResizedCropVideo", + "CenterCropVideo", + "NormalizeVideo", + "ToTensorVideo", + "RandomHorizontalFlipVideo", +] + + +class RandomCropVideo(RandomCrop): + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: randomly cropped/resized video clip. + size is (C, T, OH, OW) + """ + i, j, h, w = self.get_params(clip, self.size) + return F.crop(clip, i, j, h, w) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class RandomResizedCropVideo(RandomResizedCrop): + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError( + f"size should be tuple (height, width), instead got {size}" + ) + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + self.scale = scale + self.ratio = ratio + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: randomly cropped/resized video clip. + size is (C, T, H, W) + """ + i, j, h, w = self.get_params(clip, self.scale, self.ratio) + return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})" + + +class CenterCropVideo: + def __init__(self, crop_size): + if isinstance(crop_size, numbers.Number): + self.crop_size = (int(crop_size), int(crop_size)) + else: + self.crop_size = crop_size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: central cropping of video clip. Size is + (C, T, crop_size, crop_size) + """ + return F.center_crop(clip, self.crop_size) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(crop_size={self.crop_size})" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W) + """ + return F.normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) + Return: + clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) + """ + return F.to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizonal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (C, T, H, W) + Return: + clip (torch.tensor): Size is (C, T, H, W) + """ + if random.random() < self.p: + clip = F.hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" diff --git a/video_llama/processors/video_processor.py b/video_llama/processors/video_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..b272e318230c544748818476bbe4caa1fc9f847d --- /dev/null +++ b/video_llama/processors/video_processor.py @@ -0,0 +1,237 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +from video_llama.common.registry import registry +from decord import VideoReader +import decord +import numpy as np +from video_llama.processors import transforms_video +from video_llama.processors.base_processor import BaseProcessor +from video_llama.processors.randaugment import VideoRandomAugment +from video_llama.processors import functional_video as F +from omegaconf import OmegaConf +from torchvision import transforms +import random as rnd + + +MAX_INT = registry.get("MAX_INT") +decord.bridge.set_bridge("torch") + +def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform", return_msg = False): + decord.bridge.set_bridge("torch") + vr = VideoReader(uri=video_path, height=height, width=width) + + vlen = len(vr) + start, end = 0, vlen + + n_frms = min(n_frms, vlen) + + if sampling == "uniform": + indices = np.arange(start, end, vlen / n_frms).astype(int).tolist() + elif sampling == "headtail": + indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2)) + indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2)) + indices = indices_h + indices_t + else: + raise NotImplementedError + + # get_batch -> T, H, W, C + temp_frms = vr.get_batch(indices) + # print(type(temp_frms)) + tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms + frms = tensor_frms.permute(3, 0, 1, 2).float() # (C, T, H, W) + + if not return_msg: + return frms + + fps = float(vr.get_avg_fps()) + sec = ", ".join([str(round(f / fps, 1)) for f in indices]) + # " " should be added in the start and end + msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. " + return frms, msg + + +class AlproVideoBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None, n_frms=MAX_INT): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms_video.NormalizeVideo(mean, std) + + self.n_frms = n_frms + + +class ToUint8(object): + def __init__(self): + pass + + def __call__(self, tensor): + return tensor.to(torch.uint8) + + def __repr__(self): + return self.__class__.__name__ + + +class ToTHWC(object): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C) + """ + + def __init__(self): + pass + + def __call__(self, tensor): + return tensor.permute(1, 2, 3, 0) + + def __repr__(self): + return self.__class__.__name__ + + +class ResizeVideo(object): + def __init__(self, target_size, interpolation_mode="bilinear"): + self.target_size = target_size + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: central cropping of video clip. Size is + (C, T, crop_size, crop_size) + """ + return F.resize(clip, self.target_size, self.interpolation_mode) + + def __repr__(self): + return self.__class__.__name__ + "(resize_size={0})".format(self.target_size) + + +@registry.register_processor("alpro_video_train") +class AlproVideoTrainProcessor(AlproVideoBaseProcessor): + def __init__( + self, + image_size=384, + mean=None, + std=None, + min_scale=0.5, + max_scale=1.0, + n_frms=MAX_INT, + ): + super().__init__(mean=mean, std=std, n_frms=n_frms) + + self.image_size = image_size + + self.transform = transforms.Compose( + [ + # Video size is (C, T, H, W) + transforms_video.RandomResizedCropVideo( + image_size, + scale=(min_scale, max_scale), + interpolation_mode="bicubic", + ), + ToTHWC(), # C, T, H, W -> T, H, W, C + ToUint8(), + transforms_video.ToTensorVideo(), # T, H, W, C -> C, T, H, W + self.normalize, + ] + ) + + def __call__(self, vpath): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: video clip after transforms. Size is (C, T, size, size). + """ + clip = load_video( + video_path=vpath, + n_frms=self.n_frms, + height=self.image_size, + width=self.image_size, + sampling="headtail", + ) + + return self.transform(clip) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 256) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + n_frms = cfg.get("n_frms", MAX_INT) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + n_frms=n_frms, + ) + + +@registry.register_processor("alpro_video_eval") +class AlproVideoEvalProcessor(AlproVideoBaseProcessor): + def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT): + super().__init__(mean=mean, std=std, n_frms=n_frms) + + self.image_size = image_size + + # Input video size is (C, T, H, W) + self.transform = transforms.Compose( + [ + # frames will be resized during decord loading. + ToUint8(), # C, T, H, W + ToTHWC(), # T, H, W, C + transforms_video.ToTensorVideo(), # C, T, H, W + self.normalize, # C, T, H, W + ] + ) + + def __call__(self, vpath): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: video clip after transforms. Size is (C, T, size, size). + """ + clip = load_video( + video_path=vpath, + n_frms=self.n_frms, + height=self.image_size, + width=self.image_size, + ) + + return self.transform(clip) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 256) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + n_frms = cfg.get("n_frms", MAX_INT) + + return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms) diff --git a/video_llama/runners/__init__.py b/video_llama/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffe5b0b10e013fb6d69eb6879b1e42c06d5b447 --- /dev/null +++ b/video_llama/runners/__init__.py @@ -0,0 +1,10 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from video_llama.runners.runner_base import RunnerBase + +__all__ = ["RunnerBase"] diff --git a/video_llama/runners/runner_base.py b/video_llama/runners/runner_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c944123917dd0bf9947f4204f9044538a0f8bf22 --- /dev/null +++ b/video_llama/runners/runner_base.py @@ -0,0 +1,658 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import json +import logging +import os +import time +from pathlib import Path + +import torch +import torch.distributed as dist +import webdataset as wds +from video_llama.common.dist_utils import ( + download_cached_file, + get_rank, + get_world_size, + is_main_process, + main_process, +) +from video_llama.common.registry import registry +from video_llama.common.utils import is_url +from video_llama.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset +from video_llama.datasets.datasets.dataloader_utils import ( + IterLoader, + MultiIterLoader, + PrefetchLoader, +) +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler + + +@registry.register_runner("runner_base") +class RunnerBase: + """ + A runner class to train and evaluate a model given a task and datasets. + + The runner uses pytorch distributed data parallel by default. Future release + will support other distributed frameworks. + """ + + def __init__(self, cfg, task, model, datasets, job_id): + self.config = cfg + self.job_id = job_id + + self.task = task + self.datasets = datasets + + self._model = model + + self._wrapped_model = None + self._device = None + self._optimizer = None + self._scaler = None + self._dataloaders = None + self._lr_sched = None + + self.start_epoch = 0 + + # self.setup_seeds() + self.setup_output_dir() + + @property + def device(self): + if self._device is None: + self._device = torch.device(self.config.run_cfg.device) + + return self._device + + @property + def use_distributed(self): + return self.config.run_cfg.distributed + + @property + def model(self): + """ + A property to get the DDP-wrapped model on the device. + """ + # move model to device + if self._model.device != self.device: + self._model = self._model.to(self.device) + + # distributed training wrapper + if self.use_distributed: + if self._wrapped_model is None: + self._wrapped_model = DDP( + self._model, device_ids=[self.config.run_cfg.gpu] + ) + else: + self._wrapped_model = self._model + + return self._wrapped_model + + @property + def optimizer(self): + # TODO make optimizer class and configurations + if self._optimizer is None: + num_parameters = 0 + p_wd, p_non_wd = [], [] + for n, p in self.model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + print(n) + if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: + p_non_wd.append(p) + else: + p_wd.append(p) + num_parameters += p.data.nelement() + logging.info("number of trainable parameters: %d" % num_parameters) + optim_params = [ + { + "params": p_wd, + "weight_decay": float(self.config.run_cfg.weight_decay), + }, + {"params": p_non_wd, "weight_decay": 0}, + ] + beta2 = self.config.run_cfg.get("beta2", 0.999) + self._optimizer = torch.optim.AdamW( + optim_params, + lr=float(self.config.run_cfg.init_lr), + weight_decay=float(self.config.run_cfg.weight_decay), + betas=(0.9, beta2), + ) + + return self._optimizer + + @property + def scaler(self): + amp = self.config.run_cfg.get("amp", False) + + if amp: + if self._scaler is None: + self._scaler = torch.cuda.amp.GradScaler() + + return self._scaler + + @property + def lr_scheduler(self): + """ + A property to get and create learning rate scheduler by split just in need. + """ + if self._lr_sched is None: + lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched) + + # max_epoch = self.config.run_cfg.max_epoch + max_epoch = self.max_epoch + # min_lr = self.config.run_cfg.min_lr + min_lr = self.min_lr + # init_lr = self.config.run_cfg.init_lr + init_lr = self.init_lr + + # optional parameters + decay_rate = self.config.run_cfg.get("lr_decay_rate", None) + warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1) + warmup_steps = self.config.run_cfg.get("warmup_steps", 0) + iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None) + + if iters_per_epoch is None: + try: + iters_per_epoch = len(self.dataloaders['train']) + except (AttributeError, TypeError): + iters_per_epoch = 10000 + + self._lr_sched = lr_sched_cls( + optimizer=self.optimizer, + max_epoch=max_epoch, + iters_per_epoch=iters_per_epoch, + min_lr=min_lr, + init_lr=init_lr, + decay_rate=decay_rate, + warmup_start_lr=warmup_start_lr, + warmup_steps=warmup_steps, + ) + + return self._lr_sched + + @property + def dataloaders(self) -> dict: + """ + A property to get and create dataloaders by split just in need. + + If no train_dataset_ratio is provided, concatenate map-style datasets and + chain wds.DataPipe datasets separately. Training set becomes a tuple + (ConcatDataset, ChainDataset), both are optional but at least one of them is + required. The resultant ConcatDataset and ChainDataset will be sampled evenly. + + If train_dataset_ratio is provided, create a MultiIterLoader to sample + each dataset by ratios during training. + + Currently do not support multiple datasets for validation and test. + + Returns: + dict: {split_name: (tuples of) dataloader} + """ + if self._dataloaders is None: + + # concatenate map-style datasets and chain wds.DataPipe datasets separately + # training set becomes a tuple (ConcatDataset, ChainDataset), both are + # optional but at least one of them is required. The resultant ConcatDataset + # and ChainDataset will be sampled evenly. + logging.info( + "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." + ) + + datasets = reorg_datasets_by_split(self.datasets) + self.datasets = datasets + # self.datasets = concat_datasets(datasets) + + # print dataset statistics after concatenation/chaining + for split_name in self.datasets: + if isinstance(self.datasets[split_name], tuple) or isinstance( + self.datasets[split_name], list + ): + # mixed wds.DataPipeline and torch.utils.data.Dataset + num_records = sum( + [ + len(d) + if not type(d) in [wds.DataPipeline, ChainDataset] + else 0 + for d in self.datasets[split_name] + ] + ) + + else: + if hasattr(self.datasets[split_name], "__len__"): + # a single map-style dataset + num_records = len(self.datasets[split_name]) + else: + # a single wds.DataPipeline + num_records = -1 + logging.info( + "Only a single wds.DataPipeline dataset, no __len__ attribute." + ) + + if num_records >= 0: + logging.info( + "Loaded {} records for {} split from the dataset.".format( + num_records, split_name + ) + ) + + # create dataloaders + split_names = sorted(self.datasets.keys()) + + datasets = [self.datasets[split] for split in split_names] + is_trains = [split in self.train_splits for split in split_names] + + batch_sizes = [ + self.config.run_cfg.batch_size_train + if split == "train" + else self.config.run_cfg.batch_size_eval + for split in split_names + ] + + collate_fns = [] + for dataset in datasets: + if isinstance(dataset, tuple) or isinstance(dataset, list): + collate_fns.append([getattr(d, "collater", None) for d in dataset]) + else: + collate_fns.append(getattr(dataset, "collater", None)) + + dataloaders = self.create_loaders( + datasets=datasets, + num_workers=self.config.run_cfg.num_workers, + batch_sizes=batch_sizes, + is_trains=is_trains, + collate_fns=collate_fns, + ) + + self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)} + + return self._dataloaders + + @property + def cuda_enabled(self): + return self.device.type == "cuda" + + @property + def max_epoch(self): + return int(self.config.run_cfg.max_epoch) + + @property + def log_freq(self): + log_freq = self.config.run_cfg.get("log_freq", 50) + return int(log_freq) + + @property + def init_lr(self): + return float(self.config.run_cfg.init_lr) + + @property + def min_lr(self): + return float(self.config.run_cfg.min_lr) + + @property + def accum_grad_iters(self): + return int(self.config.run_cfg.get("accum_grad_iters", 1)) + + @property + def valid_splits(self): + valid_splits = self.config.run_cfg.get("valid_splits", []) + + if len(valid_splits) == 0: + logging.info("No validation splits found.") + + return valid_splits + + @property + def test_splits(self): + test_splits = self.config.run_cfg.get("test_splits", []) + + return test_splits + + @property + def train_splits(self): + train_splits = self.config.run_cfg.get("train_splits", []) + + if len(train_splits) == 0: + logging.info("Empty train splits.") + + return train_splits + + @property + def evaluate_only(self): + """ + Set to True to skip training. + """ + return self.config.run_cfg.evaluate + + @property + def use_dist_eval_sampler(self): + return self.config.run_cfg.get("use_dist_eval_sampler", True) + + @property + def resume_ckpt_path(self): + return self.config.run_cfg.get("resume_ckpt_path", None) + + @property + def train_loader(self): + train_dataloader = self.dataloaders["train"] + + return train_dataloader + + def setup_output_dir(self): + lib_root = Path(registry.get_path("library_root")) + + output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id + result_dir = output_dir / "result" + + output_dir.mkdir(parents=True, exist_ok=True) + result_dir.mkdir(parents=True, exist_ok=True) + + registry.register_path("result_dir", str(result_dir)) + registry.register_path("output_dir", str(output_dir)) + + self.result_dir = result_dir + self.output_dir = output_dir + + def train(self): + start_time = time.time() + best_agg_metric = 0 + best_epoch = 0 + + self.log_config() + + # resume from checkpoint if specified + if not self.evaluate_only and self.resume_ckpt_path is not None: + self._load_checkpoint(self.resume_ckpt_path) + + for cur_epoch in range(self.start_epoch, self.max_epoch): + # training phase + if not self.evaluate_only: + logging.info("Start training") + train_stats = self.train_epoch(cur_epoch) + self.log_stats(split_name="train", stats=train_stats) + + # evaluation phase + if len(self.valid_splits) > 0: + for split_name in self.valid_splits: + logging.info("Evaluating on {}.".format(split_name)) + + val_log = self.eval_epoch( + split_name=split_name, cur_epoch=cur_epoch + ) + if val_log is not None: + if is_main_process(): + assert ( + "agg_metrics" in val_log + ), "No agg_metrics found in validation log." + + agg_metrics = val_log["agg_metrics"] + if agg_metrics > best_agg_metric and split_name == "val": + best_epoch, best_agg_metric = cur_epoch, agg_metrics + + self._save_checkpoint(cur_epoch, is_best=True) + + val_log.update({"best_epoch": best_epoch}) + self.log_stats(val_log, split_name) + + else: + # if no validation split is provided, we just save the checkpoint at the end of each epoch. + if not self.evaluate_only: + self._save_checkpoint(cur_epoch, is_best=False) + + if self.evaluate_only: + break + + if self.config.run_cfg.distributed: + dist.barrier() + + # testing phase + test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch + self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Training time {}".format(total_time_str)) + + def evaluate(self, cur_epoch="best", skip_reload=False): + test_logs = dict() + + if len(self.test_splits) > 0: + for split_name in self.test_splits: + test_logs[split_name] = self.eval_epoch( + split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload + ) + + return test_logs + + def train_epoch(self, epoch): + # train + self.model.train() + + return self.task.train_epoch( + epoch=epoch, + model=self.model, + data_loader=self.train_loader, + optimizer=self.optimizer, + scaler=self.scaler, + lr_scheduler=self.lr_scheduler, + cuda_enabled=self.cuda_enabled, + log_freq=self.log_freq, + accum_grad_iters=self.accum_grad_iters, + ) + + @torch.no_grad() + def eval_epoch(self, split_name, cur_epoch, skip_reload=False): + """ + Evaluate the model on a given split. + + Args: + split_name (str): name of the split to evaluate on. + cur_epoch (int): current epoch. + skip_reload_best (bool): whether to skip reloading the best checkpoint. + During training, we will reload the best checkpoint for validation. + During testing, we will use provided weights and skip reloading the best checkpoint . + """ + data_loader = self.dataloaders.get(split_name, None) + assert data_loader, "data_loader for split {} is None.".format(split_name) + + # TODO In validation, you need to compute loss as well as metrics + # TODO consider moving to model.before_evaluation() + model = self.unwrap_dist_model(self.model) + if not skip_reload and cur_epoch == "best": + model = self._reload_best_model(model) + model.eval() + + self.task.before_evaluation( + model=model, + dataset=self.datasets[split_name], + ) + results = self.task.evaluation(model, data_loader) + + if results is not None: + return self.task.after_evaluation( + val_result=results, + split_name=split_name, + epoch=cur_epoch, + ) + + def unwrap_dist_model(self, model): + if self.use_distributed: + return model.module + else: + return model + + def create_loaders( + self, + datasets, + num_workers, + batch_sizes, + is_trains, + collate_fns, + dataset_ratios=None, + ): + """ + Create dataloaders for training and validation. + """ + + def _create_loader(dataset, num_workers, bsz, is_train, collate_fn): + # create a single dataloader for each split + if isinstance(dataset, ChainDataset) or isinstance( + dataset, wds.DataPipeline + ): + # wds.WebdDataset instance are chained together + # webdataset.DataPipeline has its own sampler and collate_fn + loader = iter( + DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + ) + ) + else: + # map-style dataset are concatenated together + # setup distributed sampler + if self.use_distributed: + sampler = DistributedSampler( + dataset, + shuffle=is_train, + num_replicas=get_world_size(), + rank=get_rank(), + ) + if not self.use_dist_eval_sampler: + # e.g. retrieval evaluation + sampler = sampler if is_train else None + else: + sampler = None + + loader = DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + shuffle=sampler is None and is_train, + collate_fn=collate_fn, + drop_last=True if is_train else False, + ) + loader = PrefetchLoader(loader) + + if is_train: + loader = IterLoader(loader, use_distributed=self.use_distributed) + + return loader + + loaders = [] + + for dataset, bsz, is_train, collate_fn in zip( + datasets, batch_sizes, is_trains, collate_fns + ): + if isinstance(dataset, list) or isinstance(dataset, tuple): + if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None: + dataset_ratios = [d.sample_ratio for d in dataset] + loader = MultiIterLoader( + loaders=[ + _create_loader(d, num_workers, bsz, is_train, collate_fn[i]) + for i, d in enumerate(dataset) + ], + ratios=dataset_ratios, + ) + else: + loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn) + + loaders.append(loader) + + return loaders + + @main_process + def _save_checkpoint(self, cur_epoch, is_best=False): + """ + Save the checkpoint at the current epoch. + """ + model_no_ddp = self.unwrap_dist_model(self.model) + param_grad_dic = { + k: v.requires_grad for (k, v) in model_no_ddp.named_parameters() + } + state_dict = model_no_ddp.state_dict() + for k in list(state_dict.keys()): + if k in param_grad_dic.keys() and not param_grad_dic[k]: + # delete parameters that do not require gradient + del state_dict[k] + save_obj = { + "model": state_dict, + "optimizer": self.optimizer.state_dict(), + "config": self.config.to_dict(), + "scaler": self.scaler.state_dict() if self.scaler else None, + "epoch": cur_epoch, + } + save_to = os.path.join( + self.output_dir, + "checkpoint_{}.pth".format("best" if is_best else cur_epoch), + ) + logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to)) + torch.save(save_obj, save_to) + + def _reload_best_model(self, model): + """ + Load the best checkpoint for evaluation. + """ + checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth") + + logging.info("Loading checkpoint from {}.".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + try: + model.load_state_dict(checkpoint["model"]) + except RuntimeError as e: + logging.warning( + """ + Key mismatch when loading checkpoint. This is expected if only part of the model is saved. + Trying to load the model with strict=False. + """ + ) + model.load_state_dict(checkpoint["model"], strict=False) + return model + + def _load_checkpoint(self, url_or_filename): + """ + Resume from a checkpoint. + """ + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location=self.device, strict=False) + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False) + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + self.unwrap_dist_model(self.model).load_state_dict(state_dict) + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if self.scaler and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.start_epoch = checkpoint["epoch"] + 1 + logging.info("Resume checkpoint from {}".format(url_or_filename)) + + @main_process + def log_stats(self, stats, split_name): + if isinstance(stats, dict): + log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}} + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(log_stats) + "\n") + elif isinstance(stats, list): + pass + + @main_process + def log_config(self): + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(self.config.to_dict(), indent=4) + "\n") diff --git a/video_llama/runners/test.py b/video_llama/runners/test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video_llama/tasks/__init__.py b/video_llama/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86c47d8aafe2637e13f3d837904a0f51dc96b379 --- /dev/null +++ b/video_llama/tasks/__init__.py @@ -0,0 +1,28 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from video_llama.common.registry import registry +from video_llama.tasks.base_task import BaseTask +from video_llama.tasks.image_text_pretrain import ImageTextPretrainTask +from video_llama.tasks.video_text_pretrain import VideoTextPretrainTask + + +def setup_task(cfg): + assert "task" in cfg.run_cfg, "Task name must be provided." + + task_name = cfg.run_cfg.task + task = registry.get_task_class(task_name).setup_task(cfg=cfg) + assert task is not None, "Task {} not properly registered.".format(task_name) + + return task + + +__all__ = [ + "BaseTask", + "ImageTextPretrainTask", + "VideoTextPretrainTask" +] diff --git a/video_llama/tasks/base_task.py b/video_llama/tasks/base_task.py new file mode 100644 index 0000000000000000000000000000000000000000..d91aa346df45ed58b7d2425a59dd21f3ae8ef93c --- /dev/null +++ b/video_llama/tasks/base_task.py @@ -0,0 +1,286 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import torch +import torch.distributed as dist +from video_llama.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized +from video_llama.common.logger import MetricLogger, SmoothedValue +from video_llama.common.registry import registry +from video_llama.datasets.data_utils import prepare_sample + + +class BaseTask: + def __init__(self, **kwargs): + super().__init__() + + self.inst_id_key = "instance_id" + + @classmethod + def setup_task(cls, **kwargs): + return cls() + + def build_model(self, cfg): + model_config = cfg.model_cfg + + model_cls = registry.get_model_class(model_config.arch) + return model_cls.from_config(model_config) + + def build_datasets(self, cfg): + """ + Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. + Download dataset and annotations automatically if not exist. + + Args: + cfg (common.config.Config): _description_ + + Returns: + dict: Dictionary of torch.utils.data.Dataset objects by split. + """ + + datasets = dict() + + datasets_config = cfg.datasets_cfg + + assert len(datasets_config) > 0, "At least one dataset has to be specified." + + for name in datasets_config: + dataset_config = datasets_config[name] + + builder = registry.get_builder_class(name)(dataset_config) + dataset = builder.build_datasets() + + dataset['train'].name = name + if 'sample_ratio' in dataset_config: + dataset['train'].sample_ratio = dataset_config.sample_ratio + + datasets[name] = dataset + + return datasets + + def train_step(self, model, samples): + loss = model(samples)["loss"] + return loss + + def valid_step(self, model, samples): + raise NotImplementedError + + def before_evaluation(self, model, dataset, **kwargs): + model.before_evaluation(dataset=dataset, task_type=type(self)) + + def after_evaluation(self, **kwargs): + pass + + def inference_step(self): + raise NotImplementedError + + def evaluation(self, model, data_loader, cuda_enabled=True): + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation" + # TODO make it configurable + print_freq = 10 + + results = [] + + for samples in metric_logger.log_every(data_loader, print_freq, header): + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + + eval_output = self.valid_step(model=model, samples=samples) + results.extend(eval_output) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return results + + def train_epoch( + self, + epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + iters_per_epoch=lr_scheduler.iters_per_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def train_iters( + self, + epoch, + start_iters, + iters_per_inner_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + start_iters=start_iters, + iters_per_epoch=iters_per_inner_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def _train_inner_loop( + self, + epoch, + iters_per_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + start_iters=None, + log_freq=50, + cuda_enabled=False, + accum_grad_iters=1, + ): + """ + An inner training loop compatible with both epoch-based and iter-based training. + + When using epoch-based, training stops after one epoch; when using iter-based, + training stops after #iters_per_epoch iterations. + """ + use_amp = scaler is not None + + if not hasattr(data_loader, "__next__"): + # convert to iterator if not already + data_loader = iter(data_loader) + + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) + + # if iter-based runner, schedule lr based on inner epoch. + logging.info( + "Start training epoch {}, {} iters per inner epoch.".format( + epoch, iters_per_epoch + ) + ) + header = "Train: data epoch: [{}]".format(epoch) + if start_iters is None: + # epoch-based runner + inner_epoch = epoch + else: + # In iter-based runner, we schedule the learning rate based on iterations. + inner_epoch = start_iters // iters_per_epoch + header = header + "; inner epoch [{}]".format(inner_epoch) + + for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): + # if using iter-based runner, we stop after iters_per_epoch iterations. + if i >= iters_per_epoch: + break + + samples = next(data_loader) + + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + samples.update( + { + "epoch": inner_epoch, + "num_iters_per_epoch": iters_per_epoch, + "iters": i, + } + ) + + lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) + + with torch.cuda.amp.autocast(enabled=use_amp): + loss = self.train_step(model=model, samples=samples) + + # after_train_step() + if use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + # update gradients every accum_grad_iters iterations + if (i + 1) % accum_grad_iters == 0: + if use_amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # after train_epoch() + # gather the stats from all processes + metric_logger.synchronize_between_processes() + logging.info("Averaged stats: " + str(metric_logger.global_avg())) + return { + k: "{:.3f}".format(meter.global_avg) + for k, meter in metric_logger.meters.items() + } + + @staticmethod + def save_result(result, result_dir, filename, remove_duplicate=""): + import json + + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, get_rank()) + ) + final_result_file = os.path.join(result_dir, "%s.json" % filename) + + json.dump(result, open(result_file, "w")) + + if is_dist_avail_and_initialized(): + dist.barrier() + + if is_main_process(): + logging.warning("rank %d starts merging results." % get_rank()) + # combine results from all processes + result = [] + + for rank in range(get_world_size()): + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, rank) + ) + res = json.load(open(result_file, "r")) + result += res + + if remove_duplicate: + result_new = [] + id_list = [] + for res in result: + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + json.dump(result, open(final_result_file, "w")) + print("result file saved to %s" % final_result_file) + + return final_result_file diff --git a/video_llama/tasks/image_text_pretrain.py b/video_llama/tasks/image_text_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..db955f27bb7dc8093cffd95b3a26917bb681c846 --- /dev/null +++ b/video_llama/tasks/image_text_pretrain.py @@ -0,0 +1,18 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from video_llama.common.registry import registry +from video_llama.tasks.base_task import BaseTask + + +@registry.register_task("image_text_pretrain") +class ImageTextPretrainTask(BaseTask): + def __init__(self): + super().__init__() + + def evaluation(self, model, data_loader, cuda_enabled=True): + pass diff --git a/video_llama/tasks/video_text_pretrain.py b/video_llama/tasks/video_text_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf72e5e878500e0cc3aa719c7cb20b56f63c71a --- /dev/null +++ b/video_llama/tasks/video_text_pretrain.py @@ -0,0 +1,18 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from video_llama.common.registry import registry +from video_llama.tasks.base_task import BaseTask + + +@registry.register_task("video_text_pretrain") +class VideoTextPretrainTask(BaseTask): + def __init__(self): + super().__init__() + + def evaluation(self, model, data_loader, cuda_enabled=True): + pass