"""
This file implements a multiprocess real-time long video understanding system
with a multiprocess logging module.

main process: CLI server I/O, LLM inference
process-1: logger listener
process-2: frame generator (video stream simulator)
process-3: frame memory manager

Author: Haoji Zhang, Haotian Liu
(This code is based on https://github.com/haotian-liu/LLaVA)
"""

import argparse
import logging
import logging.handlers  # required for logging.handlers.QueueHandler below
import os
import time

import numpy as np
import requests
import torch

from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from flash_vstream.conversation import conv_templates, SeparatorStyle
from flash_vstream.model.builder import load_pretrained_model
from flash_vstream.utils import disable_torch_init
from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from torch.multiprocessing import Process, Queue, Manager
from transformers import TextStreamer
from decord import VideoReader
from datetime import datetime
from PIL import Image
from io import BytesIO

class _Metric:
    """Track the latest, running-average, and maximum value of a metric."""

    def __init__(self):
        self._latest_value = None
        self._sum = 0.0
        self._max = 0.0
        self._count = 0

    @property
    def val(self):
        return self._latest_value

    @property
    def max(self):
        return self._max

    @property
    def avg(self):
        if self._count == 0:
            return float('nan')
        return self._sum / self._count

    def add(self, value):
        self._latest_value = value
        self._sum += value
        self._count += 1
        if value > self._max:
            self._max = value

    def __str__(self):
        # Render as "latest (average, max)".
        latest_formatted = f"{self.val:.6f}" if self.val is not None else "None"
        average_formatted = f"{self.avg:.6f}"
        max_formatted = f"{self.max:.6f}"
        return f"{latest_formatted} ({average_formatted}, {max_formatted})"


class MetricMeter:
    """A keyed collection of _Metric objects."""

    def __init__(self):
        self._metrics = {}

    def add(self, key, value):
        if key not in self._metrics:
            self._metrics[key] = _Metric()
        self._metrics[key].add(value)

    def val(self, key):
        metric = self._metrics.get(key)
        if metric is None or metric.val is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric.val

    def avg(self, key):
        metric = self._metrics.get(key)
        if metric is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric.avg

    def max(self, key):
        metric = self._metrics.get(key)
        if metric is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric.max

    def __getitem__(self, key):
        metric = self._metrics.get(key)
        if metric is None:
            raise KeyError(f"The key '{key}' does not exist.")
        return str(metric)
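
# Illustrative usage of MetricMeter (not executed; values are made up):
#
#   meter = MetricMeter()
#   meter.add('latency', 2.0)
#   meter.add('latency', 4.0)
#   meter.val('latency')   # 4.0  (latest)
#   meter.avg('latency')   # 3.0  (running average)
#   meter.max('latency')   # 4.0
#   meter['latency']       # "4.000000 (3.000000, 4.000000)"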


def load_image(image_file):
    # Load an RGB image from an http(s) URL or a local file path.
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
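
# Illustrative usage (the URL is a placeholder):
#
#   img = load_image('https://example.com/frame.jpg')  # or a local path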


def listener(queue, filename):
    # Logger listener process: receives LogRecords from worker processes
    # through `queue` and writes them to a single log file.
    import sys, traceback
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    h = logging.FileHandler(filename)
    f = logging.Formatter('%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s')
    h.setFormatter(f)
    root.addHandler(h)
    while True:
        try:
            record = queue.get()
            if record is None:  # None is the sentinel that stops the listener
                break
            logger = logging.getLogger(record.name)
            logger.handle(record)
        except Exception:
            print('Whoops! Problem:', file=sys.stderr)
            traceback.print_exc(file=sys.stderr)


def worker_configurer(queue):
    # Called in each worker process: route all of its log records to the
    # listener process through a QueueHandler.
    h = logging.handlers.QueueHandler(queue)
    root = logging.getLogger()
    root.addHandler(h)
    root.setLevel(logging.DEBUG)
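
# The listener/worker pair above follows the standard `logging` cookbook
# pattern for multiprocess logging. A minimal self-contained sketch of the
# same idea (illustrative only; run it under an `if __name__ == "__main__"`
# guard):
#
#   import logging, logging.handlers
#   from torch.multiprocessing import Process, Queue
#
#   def worker(q):
#       logging.getLogger().addHandler(logging.handlers.QueueHandler(q))
#       logging.getLogger().setLevel(logging.DEBUG)
#       logging.getLogger(__name__).info('hello from worker')
#
#   q = Queue()
#   p = Process(target=listener, args=(q, 'demo.log'))
#   p.start()
#   worker(q)
#   q.put(None)  # sentinel: tells the listener to exit
#   p.join()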


def video_stream_simulator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0):
    # Frame generator process: decodes the video, then pushes frames into
    # `frame_queue` one at a time, sleeping between frames to simulate a live
    # stream played at `play_speed` times real time.
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    vr = VideoReader(video_file)
    # Sample the source down to `video_fps` frames per second.
    sample_interval = round(vr.get_avg_fps() / video_fps)
    frame_idx = [i for i in range(0, len(vr), sample_interval)]
    video = vr.get_batch(frame_idx).asnumpy()
    # Repeat every frame 6 times to lengthen the stream for testing.
    video = np.repeat(video, 6, axis=0)
    length = video.shape[0]
    sleep_time = 1 / video_fps / play_speed
    time_meter = MetricMeter()
    logger.info(f'Simulator Process: start, length = {length}')
    try:
        for start in range(0, length):
            start_time = time.perf_counter()
            end = min(start + 1, length)
            video_clip = video[start:end]
            frame_queue.put(video_clip)
            if start > 0:
                time_meter.add('real_sleep', start_time - last_start)
                logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}')
            if end < length:
                time.sleep(sleep_time)
            last_start = start_time
        frame_queue.put(None)  # sentinel: no more frames
    except Exception as e:
        print(f'Simulator Exception: {e}')
        time.sleep(0.1)
    logger.info('Simulator Process: end')
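
# Sampling arithmetic, worked example (illustrative numbers): a 30 fps source
# with video_fps=1.0 gives sample_interval = round(30 / 1.0) = 30, i.e. one
# decoded frame per second of video; with play_speed=2.0 the simulator sleeps
# 1 / 1.0 / 2.0 = 0.5 s between frames, streaming at twice real time.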


def frame_memory_manager(model, image_processor, frame_queue, log_queue):
    # Frame memory manager process: consumes frames from `frame_queue`,
    # preprocesses them, and embeds them into the model's shared video memory.
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    time_meter = MetricMeter()
    logger.info('MemManager Process: start')
    frame_cnt = 0
    while True:
        try:
            video_clip = frame_queue.get()
            start_time = time.perf_counter()
            if video_clip is None:  # sentinel from the simulator: stream ended
                logger.info('MemManager: Ooops, get None')
                break
            logger.info(f'MemManager: get {video_clip.shape[0]} frames from queue')
            image = image_processor.preprocess(video_clip, return_tensors='pt')['pixel_values']
            image = image.unsqueeze(0)
            image_tensor = image.to(model.device, dtype=torch.float16)

            logger.info('MemManager: Start embedding')
            with torch.inference_mode():
                model.embed_video_streaming(image_tensor)
            logger.info('MemManager: End embedding')
            end_time = time.perf_counter()
            # Skip the first clip when accumulating latency: it includes warmup cost.
            if frame_cnt > 0:
                time_meter.add('memory_latency', end_time - start_time)
                logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={time_meter["memory_latency"]}')
            else:
                logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={end_time - start_time:.6f}, not logged')
            frame_cnt += video_clip.shape[0]
        except Exception as e:
            print(f'MemManager Exception: {e}')
            time.sleep(0.1)
    logger.info('MemManager Process: end')
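
# Shape sketch for one pass through the manager (assuming a CLIP-style image
# processor with 336x336 crops; actual sizes depend on the checkpoint):
#
#   video_clip:   np.ndarray   (1, H, W, 3)        one RGB frame off the queue
#   pixel_values: torch.Tensor (1, 3, 336, 336)    after preprocess()
#   image_tensor: torch.Tensor (1, 1, 3, 336, 336) after unsqueeze(0), fp16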


def main(args):
    # 'spawn' is required so the CUDA model can be shared with child processes.
    torch.multiprocessing.set_start_method('spawn', force=True)
    disable_torch_init()

    log_queue = Queue()
    frame_queue = Queue(maxsize=10)
    processes = []

    # process-1: logger listener
    p1 = Process(target=listener, args=(log_queue, args.log_file))
    processes.append(p1)
    p1.start()

    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    logger.info(f'Using conv_mode={args.conv_mode}')

    conv = conv_templates[args.conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    with Manager() as manager:
        # In streaming mode the model reads embedded frames from this shared
        # memory list, which is written by the frame memory manager process.
        image_tensor = None
        model.use_video_streaming_mode = True
        model.video_embedding_memory = manager.list()
        if args.video_max_frames is not None:
            model.config.video_max_frames = args.video_max_frames
            logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}')

        logger.info(f'Important: set video_fps = {args.video_fps}')
        logger.info(f'Important: set play_speed = {args.play_speed}')

        # process-2: frame generator
        p2 = Process(target=video_stream_simulator,
                     args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed))
        processes.append(p2)
        p2.start()

        # process-3: frame memory manager
        p3 = Process(target=frame_memory_manager,
                     args=(model, image_processor, frame_queue, log_queue))
        processes.append(p3)
        p3.start()

        start_time = datetime.now()
        time_meter = MetricMeter()
        conv_cnt = 0
        while True:
            time.sleep(5)
            # The interactive prompt is replaced by a fixed question that is
            # asked every 5 seconds; an empty input would end the session.
            inp = "what is in the video?"
            if not inp:
                print("exit...")
                break

            now = datetime.now()
            conv_start_time = time.perf_counter()

            current_time = now.strftime("%H:%M:%S")
            duration = now.timestamp() - start_time.timestamp()

            print("\nCurrent Time:", current_time, "Run for:", duration)
            print(f"{roles[0]}: {inp}", end="\n")
            print(f"{roles[1]}: ", end="")

            # Start a fresh conversation for every question.
            conv = conv_templates[args.conv_mode].copy()
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            llm_start_time = time.perf_counter()
            with torch.inference_mode():
                # images is None here: in streaming mode the model reads the
                # embedded frames from model.video_embedding_memory instead.
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    streamer=streamer,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria]
                )
            llm_end_time = time.perf_counter()

            outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
            conv.messages[-1][-1] = outputs
            conv_end_time = time.perf_counter()
            # Skip the first conversation when accumulating latency statistics:
            # it includes one-time warmup cost.
            if conv_cnt > 0:
                time_meter.add('conv_latency', conv_end_time - conv_start_time)
                time_meter.add('llm_latency', llm_end_time - llm_start_time)
                time_meter.add('real_sleep', conv_start_time - last_conv_start_time)
                logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}')
            else:
                logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}')
            conv_cnt += 1
            last_conv_start_time = conv_start_time

        log_queue.put(None)  # sentinel: let the logger listener exit cleanly
        for p in processes:
            p.terminate()
        print("All processes finished.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, default=None)
    parser.add_argument("--video-file", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")

    parser.add_argument("--log-file", type=str, default="tmp_cli.log")
    parser.add_argument("--use_1process", action="store_true")
    parser.add_argument("--video_max_frames", type=int, default=None)
    parser.add_argument("--video_fps", type=float, default=1.0)
    parser.add_argument("--play_speed", type=float, default=1.0)
    args = parser.parse_args()
    main(args)
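
# Example invocation (model name and paths are placeholders):
#
#   python cli_video_stream.py \
#       --model-path IVGSZ/Flash-VStream-7b \
#       --video-file demo.mp4 \
#       --log-file tmp_cli.log \
#       --video_fps 1.0 --play_speed 2.0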