# leoxing1996
# add demo
# d16b52d
# raw
# history blame
# 1.11 kB
from typing import Optional

import torch
import torch.nn as nn
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)

from .builder import EngineBuilder
from .models import BaseModel
class TorchVAEEncoder(nn.Module):
    """Module wrapper exposing only the encode path of an ``AutoencoderKL``.

    ``forward`` runs ``vae.encode`` on the input and hands the result to
    ``retrieve_latents``, so callers get the value produced by that helper
    rather than the raw encoder output object.
    """

    def __init__(self, vae: AutoencoderKL) -> None:
        """Store the VAE whose ``encode`` method ``forward`` will call."""
        super().__init__()
        self.vae = vae

    def forward(self, x: torch.Tensor):
        """Encode ``x`` with the wrapped VAE and return its latents.

        NOTE(review): presumably ``x`` is an image batch in the layout the
        VAE expects -- confirm against the caller.
        """
        encoder_output = self.vae.encode(x)
        return retrieve_latents(encoder_output)
def compile_engine(
    torch_model: nn.Module,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_image_height: int = 512,
    opt_image_width: int = 512,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
) -> None:
    """Build an engine for ``torch_model`` through ``EngineBuilder``.

    An ``EngineBuilder`` is created for ``model_data`` and ``torch_model`` on
    the ``"cuda"`` device, and its ``build`` method is invoked with the given
    paths and optimisation settings.

    Args:
        torch_model: The module handed to ``EngineBuilder``.
        model_data: The model description handed to ``EngineBuilder``.
        onnx_path: First positional path passed to ``build`` (presumably the
            exported ONNX file -- confirm against ``EngineBuilder.build``).
        onnx_opt_path: Second positional path passed to ``build`` (presumably
            the optimised ONNX file -- confirm against ``EngineBuilder.build``).
        engine_path: Third positional path passed to ``build`` (presumably
            the resulting engine file -- confirm against
            ``EngineBuilder.build``).
        opt_image_height: Forwarded to ``build`` as ``opt_image_height``.
        opt_image_width: Forwarded to ``build`` as ``opt_image_width``.
        opt_batch_size: Forwarded to ``build`` as ``opt_batch_size``.
        engine_build_options: Extra keyword arguments unpacked into the
            ``build`` call. ``None`` (the default) means no extra arguments.
    """
    # A ``None`` sentinel replaces the former ``{}`` default so that no dict
    # instance is shared between calls.
    extra_build_kwargs = (
        {} if engine_build_options is None else engine_build_options
    )

    builder = EngineBuilder(
        model_data,
        torch_model,
        device=torch.device("cuda"),
    )
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_image_height=opt_image_height,
        opt_image_width=opt_image_width,
        opt_batch_size=opt_batch_size,
        **extra_build_kwargs,
    )