Spaces:
Sleeping
Sleeping
File size: 4,643 Bytes
1040e55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import csv
import json
import torch
import argparse
import pandas as pd
from tqdm import tqdm
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
# Command-line interface. Arguments are parsed at import time and the resulting
# module-level ``args`` namespace is read by ``main`` and ``get_nle``.
parser = argparse.ArgumentParser()
parser.add_argument('--input_file', type=str, required=True, help='input csv file')
# The output CSV is opened unconditionally during generation, so it is required
# here; otherwise a missing value only surfaces after the model has been loaded.
parser.add_argument('--output_file', type=str, required=True, help='output csv file')
parser.add_argument('--pretrained_ckpt', type=str, required=True, help='pretrained ckpt')
parser.add_argument('--trained_ckpt', type=str, help='trained ckpt')
parser.add_argument('--lora_r', type=int, default=32)
parser.add_argument('--use_lora', action='store_true', help='lora model')
parser.add_argument('--all_params', action='store_true', help='all params')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=32)

args = parser.parse_args()

# Prompt template; ``{caption}`` is filled with the caption under test and
# ``<|video|>`` is left in place for the processor.
PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: What is the misalignment between this video and the description: "{caption}"?
AI: '''

# Keyword arguments forwarded verbatim to ``model.generate``.
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512,
}
class VideoCaptionDataset(Dataset):
    """Dataset backed by a CSV file with ``videopath`` and ``neg_caption`` columns.

    The whole CSV is read into a pandas DataFrame on construction; each item
    is a dict holding those two fields for one row.
    """

    def __init__(self, input_file):
        """Read ``input_file`` with ``pandas.read_csv`` and keep the frame."""
        self.data = pd.read_csv(input_file)

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``{'videopath': ..., 'neg_caption': ...}`` for row ``index``."""
        row = self.data.iloc[index]
        return {column: row[column] for column in ('videopath', 'neg_caption')}
def get_nle(args, model, processor, tokenizer, dataloader):
    """Generate a misalignment explanation per (video, caption) pair and append it to a CSV.

    For every batch, each caption is inserted into ``PROMPT_FEEDBACK``, the
    prompts and video paths are run through ``processor``, and the output of
    ``model.generate`` is decoded with ``tokenizer``. One row
    ``[videopath, neg_caption, generated_text]`` is appended to
    ``args.output_file`` per sample.

    Args:
        args: Namespace providing ``output_file`` and ``num_frames``.
        model: Model exposing ``generate`` and ``device``.
        processor: Callable building model inputs from ``text`` and ``videos``.
        tokenizer: Tokenizer used to decode generated token ids.
        dataloader: Iterable of batches with ``'videopath'`` and
            ``'neg_caption'`` entries.
    """
    # Open the output once instead of once per sample. newline='' is what the
    # csv module requires so row terminators and embedded newlines are written
    # correctly on every platform. The file is still opened in append mode, so
    # earlier results are kept.
    with open(args.output_file, 'a', newline='') as f, torch.no_grad():
        writer = csv.writer(f)
        # Wrapping the dataloader directly lets tqdm show the total batch count.
        for batch in tqdm(dataloader):
            videopaths = batch['videopath']
            neg_captions = batch['neg_caption']
            # One prompt per caption: previously only the first caption of the
            # batch was used and only the first result was written, silently
            # dropping every other sample when batch_size > 1.
            # NOTE(review): assumes the processor and model.generate accept
            # equally sized lists of prompts and videos - confirm before
            # running with batch_size > 1.
            prompts = [PROMPT_FEEDBACK.format(caption=caption) for caption in neg_captions]
            inputs = processor(text=prompts, videos=videopaths, num_frames=args.num_frames, return_tensors='pt')
            # Cast float32 tensors to bfloat16 and leave other dtypes as they
            # are, then move everything to the model's device.
            inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            res = model.generate(**inputs, **generate_kwargs)
            for videopath, caption, token_ids in zip(videopaths, neg_captions, res.tolist()):
                generated_nle = tokenizer.decode(token_ids, skip_special_tokens=True)
                writer.writerow([videopath, caption, generated_nle])
            # Flush per batch so finished rows survive an interrupted run, as
            # they did when the file was reopened for every sample.
            f.flush()
def main():
    """Build the dataloader, processor and model, then run explanation generation.

    Reads its configuration from the module-level ``args``. When
    ``--use_lora`` is set, the base model is wrapped with a LoRA adapter and
    the state dict stored at ``--trained_ckpt`` is loaded into it.

    Raises:
        ValueError: If ``--use_lora`` is set without ``--trained_ckpt``.
    """
    # Fail before the expensive model load rather than with a TypeError from
    # open(None) afterwards.
    if args.use_lora and not args.trained_ckpt:
        raise ValueError('--trained_ckpt is required when --use_lora is set')

    # Create dataloader
    dataset = VideoCaptionDataset(args.input_file)
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    pretrained_ckpt = args.pretrained_ckpt

    # Processors
    tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
    processor = MplugOwlProcessor(image_processor, tokenizer)

    # Instantiate model
    model = MplugOwlForConditionalGeneration.from_pretrained(
        pretrained_ckpt,
        torch_dtype=torch.bfloat16,
        device_map={'': 0},
    )

    if args.use_lora:
        # Freeze every base parameter before attaching the adapter.
        for param in model.parameters():
            param.requires_grad = False

        # --all_params additionally targets the gate/down/up projections;
        # every other LoRA setting is shared by both variants.
        if args.all_params:
            target_modules = r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)'
        else:
            target_modules = r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)'
        peft_config = LoraConfig(
            target_modules=target_modules,
            inference_mode=True,
            r=args.lora_r,
            lora_alpha=16,
            lora_dropout=0.05,
        )

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        # NOTE: torch.load unpickles the file, so only load checkpoints from a
        # trusted source.
        with open(args.trained_ckpt, 'rb') as f:
            ckpt = torch.load(f, map_location=torch.device('cuda:0'))
        model.load_state_dict(ckpt)
        model = model.to(torch.bfloat16)
        print('Model Loaded')

    model.eval()

    # get nle
    get_nle(args, model, processor, tokenizer, dataloader)


if __name__ == "__main__":
    main()