Spaces:
Sleeping
Sleeping
import os | |
import csv | |
import json | |
import torch | |
import argparse | |
import pandas as pd | |
from tqdm import tqdm | |
from peft import LoraConfig, get_peft_model | |
from torch.utils.data import Dataset, DataLoader | |
from transformers.models.llama.tokenization_llama import LlamaTokenizer | |
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration | |
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor | |
# Command-line interface. Arguments are parsed at import time and read by
# main() / get_nle() through the module-level ``args``.
parser = argparse.ArgumentParser()
parser.add_argument('--input_file', type=str, required=True, help='input csv file')
# Required: results are appended to this path, so a missing value would
# otherwise only fail after the model had already been loaded.
parser.add_argument('--output_file', type=str, required=True, help='output csv file')
parser.add_argument('--pretrained_ckpt', type=str, required=True, help='pretrained ckpt')
parser.add_argument('--trained_ckpt', type=str, help='trained ckpt (required with --use_lora)')
parser.add_argument('--lora_r', type=int, default=32, help='LoRA rank r')
parser.add_argument('--use_lora', action='store_true', help='lora model')
parser.add_argument('--all_params', action='store_true',
                    help='all params (LoRA on MLP gate/up/down projections as well as attention)')
parser.add_argument('--batch_size', type=int, default=1, help='dataloader batch size')
parser.add_argument('--num_frames', type=int, default=32,
                    help='number of video frames passed to the processor')
args = parser.parse_args()
if args.use_lora and args.trained_ckpt is None:
    # The LoRA adapter weights come from --trained_ckpt; reject the
    # combination now instead of crashing on open(None) after model loading.
    parser.error('--trained_ckpt is required when --use_lora is set')
# Chat-style prompt fed to mPLUG-Owl-Video. ``<|video|>`` marks where the
# processor inserts the video; ``{caption}`` is filled via str.format with the
# caption whose misalignment with the video should be explained.
PROMPT_FEEDBACK = (
    "The following is a conversation between a curious human and AI assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    "Human: <|video|>\n"
    'Human: What is the misalignment between this video and the description: "{caption}"?\n'
    "AI: "
)

# Decoding settings forwarded to model.generate(): top-k sampling (k=5),
# with the total sequence length capped at 512 tokens.
generate_kwargs = dict(
    do_sample=True,
    top_k=5,
    max_length=512,
)
class VideoCaptionDataset(Dataset):
    """Map-style dataset over a CSV of (video path, negative caption) rows.

    The whole CSV is read with pandas when the dataset is constructed. Each
    item is a dict with the row's ``videopath`` and ``neg_caption`` values;
    indexing raises ``KeyError`` if either column is missing. Other columns
    are ignored.
    """

    def __init__(self, input_file):
        self.data = pd.read_csv(input_file)

    def __len__(self):
        # Number of CSV rows.
        return self.data.shape[0]

    def __getitem__(self, index):
        row = self.data.iloc[index]
        return {
            'videopath': row['videopath'],
            'neg_caption': row['neg_caption'],
        }
def _generate_nle(model, processor, tokenizer, videopath, caption, num_frames):
    """Generate the model's misalignment explanation for one (video, caption) pair.

    Returns the first generated sequence decoded with special tokens skipped.
    """
    prompts = [PROMPT_FEEDBACK.format(caption=caption)]
    inputs = processor(text=prompts, videos=[videopath], num_frames=num_frames, return_tensors='pt')
    # Cast float32 tensors to bfloat16 to match the bfloat16 model weights;
    # tensors of other dtypes are passed through unchanged. Then move to the model's device.
    inputs = {
        k: (v.bfloat16() if v.dtype == torch.float else v).to(model.device)
        for k, v in inputs.items()
    }
    res = model.generate(**inputs, **generate_kwargs)
    return tokenizer.decode(res.tolist()[0], skip_special_tokens=True)


def get_nle(args, model, processor, tokenizer, dataloader):
    """Generate a misalignment explanation for every row and append it to a CSV.

    For each (videopath, neg_caption) pair yielded by ``dataloader``, the model
    is prompted with PROMPT_FEEDBACK and one row
    ``[videopath, neg_caption, generated_text]`` is appended to
    ``args.output_file``. Every item of each batch is processed, so any
    ``--batch_size`` is handled correctly; generation itself runs one pair at a time.

    Args:
        args: parsed CLI namespace; reads ``output_file`` and ``num_frames``.
        model: generation model exposing ``generate`` and ``device``.
        processor: callable turning prompts + video paths into model inputs.
        tokenizer: tokenizer used to decode generated ids.
        dataloader: yields dicts with list-valued ``videopath`` / ``neg_caption``.
    """
    # newline='' is required by the csv module to avoid extra blank lines;
    # UTF-8 matches how pandas decoded the input captions.
    with torch.no_grad(), open(args.output_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        for batch in tqdm(dataloader):
            for videopath, neg_caption in zip(batch['videopath'], batch['neg_caption']):
                generated_nle = _generate_nle(
                    model, processor, tokenizer, videopath, neg_caption, args.num_frames
                )
                writer.writerow([videopath, neg_caption, generated_nle])
                # Persist each row immediately so completed results survive an interrupted run.
                f.flush()
# LoRA target-module regexes over the language model inside mPLUG-Owl:
# attention projections only, or attention plus the MLP projections.
_LORA_ATTN_TARGETS = r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)'
_LORA_ALL_TARGETS = r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)'


def _build_lora_config(lora_r, all_params):
    """Return an inference-mode LoRA config.

    Rank and target modules must agree with the adapter weights stored in
    --trained_ckpt, otherwise the (strict) load_state_dict call rejects them.
    """
    return LoraConfig(
        target_modules=_LORA_ALL_TARGETS if all_params else _LORA_ATTN_TARGETS,
        inference_mode=True,
        r=lora_r,
        lora_alpha=16,
        lora_dropout=0.05,
    )


def _attach_lora_adapters(model, args):
    """Freeze ``model``, wrap it with LoRA adapters and load ``args.trained_ckpt`` into it."""
    for param in model.parameters():
        param.requires_grad = False
    model = get_peft_model(model, _build_lora_config(args.lora_r, args.all_params))
    model.print_trainable_parameters()
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.trained_ckpt, map_location=torch.device('cuda:0'))
    model.load_state_dict(ckpt)
    model = model.to(torch.bfloat16)
    print('Model Loaded')
    return model


def main():
    """Load mPLUG-Owl-Video (optionally with LoRA adapters) and run get_nle over --input_file.

    Raises:
        SystemExit: if --output_file is missing, or --use_lora is set without
            --trained_ckpt (checked before any model loading).
    """
    # Validate path arguments up front so a missing flag is not discovered
    # only after the slow model load.
    if args.output_file is None:
        raise SystemExit('--output_file is required')
    if args.use_lora and args.trained_ckpt is None:
        raise SystemExit('--trained_ckpt is required when --use_lora is set')

    # Create dataloader
    dataset = VideoCaptionDataset(args.input_file)
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    # Processors
    pretrained_ckpt = args.pretrained_ckpt
    tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
    processor = MplugOwlProcessor(image_processor, tokenizer)

    # Instantiate model in bfloat16, placed entirely on GPU 0.
    model = MplugOwlForConditionalGeneration.from_pretrained(
        pretrained_ckpt,
        torch_dtype=torch.bfloat16,
        device_map={'': 0},
    )
    if args.use_lora:
        model = _attach_lora_adapters(model, args)
    model.eval()

    # get nle
    get_nle(args, model, processor, tokenizer, dataloader)
if __name__ == '__main__':
    # Script entry point: run the full inference pipeline.
    main()