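"""Gradio demo app for Depth Anything V2 monocular depth estimation.

Downloads a Depth Anything V2 checkpoint from the Hugging Face Hub and serves
an interactive UI that predicts a depth map and offers grayscale and 16-bit
raw versions for download.
"""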
import gradio as gr
import cv2
import matplotlib
import numpy as np
import os
from PIL import Image
import spaces
import torch
import tempfile
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
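# (features/out_channels above parameterize the DPT decoder head for each encoder size)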
encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large',
    'vitg': 'Giant',  # we are undergoing company review procedures to release our giant model checkpoint
}
# Initialize model as None
model = None
def load_model(encoder):
    global model
    model_name = encoder2name[encoder]
    model = DepthAnythingV2(**model_configs[encoder])
    # Download the checkpoint from the Hugging Face Hub. Note that the 'vitg'
    # (Giant) checkpoint is not yet released (see encoder2name above), so
    # selecting it will likely fail at this download step until it is published.
    filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
    state_dict = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(DEVICE).eval()
    return f"Loaded {model_name} model"
title = "# Depth Anything V2"
description = """Official demo for **Depth Anything V2**.
Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), and [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
# @spaces.GPU  # uncomment to request GPU time when hosted on a ZeroGPU Space
def predict_depth(image):
    return model.infer_image(image)
def on_submit(image, encoder):
    # Lazily (re)load the model when none is loaded yet or the encoder changed.
    if model is None or model.encoder != encoder:
        load_model(encoder)
    original_image = image.copy()
    h, w = image.shape[:2]
    # infer_image expects a BGR image (cv2 convention), so flip the RGB channels.
    depth = predict_depth(image[:, :, ::-1])
    # Save the raw model output as a 16-bit PNG.
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)
    # Normalize the depth values to the standard color range [0, 255] for
    # visualization (the image's depth information is rendered as a spectral plot).
    # (depth - depth.min()) / (depth.max() - depth.min()) guarantees the minimum
    # maps to 0 and the maximum to 1, so the whole set lies in [0, 1]. This
    # preserves the relative depth information after scaling and, in particular,
    # keeps the contrast when depth.min() >> 0.
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    gray_depth = Image.fromarray(depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)
    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
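# Note: the raw 16-bit PNG saved in on_submit stores the model's unnormalized
# output. A hedged sketch of reloading it for further processing, where
# raw_depth_path is a hypothetical path to the downloaded file:
#   raw = np.array(Image.open(raw_depth_path)).astype(np.float32)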
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
    # Encoder selection
    encoder_dropdown = gr.Dropdown(
        choices=['vits', 'vitb', 'vitl', 'vitg'],
        value='vits',
        label="Select Encoder"
    )
    submit = gr.Button(value="Compute Depth")
    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
    raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
    # Fetch matplotlib's built-in 'Spectral_r' colormap, used to present the
    # continuous depth values intuitively as a color gradient (a spectral view
    # of the image's depth information). Properties of 'Spectral_r':
    #   Color range: from deep red through yellow and green to deep blue
    #   Reversed order: the '_r' suffix means the color order is reversed
    #   High contrast: colors change noticeably, so different values are easy to distinguish
    #   Intuitive: warm colors (red, orange, yellow) map to high values, cool colors (blue, purple) to low values
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')
    submit.click(on_submit, inputs=[input_image, encoder_dropdown], outputs=[depth_image_slider, gray_depth_file, raw_file])
    example_files = os.listdir('assets/examples')
    example_files.sort()
    example_files = [[os.path.join('assets/examples', filename), "vits"] for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image, encoder_dropdown], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
if __name__ == '__main__':
    demo.queue().launch(share=True)
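# To run locally (a sketch; assumes the depth_anything_v2 package and the
# assets/examples directory ship alongside this file):
#   pip install gradio gradio_imageslider spaces torch opencv-python matplotlib pillow huggingface_hub
#   python app.py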