MKFMIKU's picture
A new start
596242b
raw
history blame
6.31 kB
"""
HF_HOME=/mnt/store/kmei1/HF_HOME/ python gradio-app.py
"""
import os

from omegaconf import OmegaConf
import torch
import torchvision
from diffusers.models import AutoencoderKL
from einops import rearrange
from models.t1 import Model as T1Model
from datasets.action_video import Dataset as ActionDataset
from diffusion import create_diffusion
import gradio as gr
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule
from tqdm import tqdm
# Torch device for the model, the VAE and every tensor this demo creates.
# Defaults to "cuda:6"; set GENGAME_DEVICE (e.g. "cuda:0" or "cpu") to override
# it without editing the script.
device = os.environ.get("GENGAME_DEVICE", "cuda:6")
class ActionGameDemo:
nframes: int = 17
memory_frames: int = 16
def __init__(self, config_path, ckpt_path) -> None:
configs = OmegaConf.load(config_path)
configs.dataset.data_root = "demo"
model = T1Model(**configs.get("model", {}))
model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["ema"])
model = model.to(device)
self.model = model
self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
dataset = ActionDataset(
**configs.get("dataset", {})
)
data = dataset[0]
x, action, pos = data
pos = torch.from_numpy(pos[None]).to(device)
x = torch.from_numpy(x[None]).to(device)
action = torch.tensor(action).to(device)
with torch.no_grad():
T = x.shape[2]
x = rearrange(x, "N C T H W -> (N T) C H W")
x = self.vae.encode(x).latent_dist.sample().mul_(0.13025)
x = rearrange(x, "(N T) C H W -> N C T H W", T=T)
self.all_frames = x
self.all_poses = pos
self.all_actions = action
self.diffusion = create_diffusion(
timestep_respacing="ddim20",
)
_video = x[0]
grid_h = np.arange(_video.shape[2] // 2, dtype=np.float32)
grid_w = np.arange(_video.shape[3] // 2, dtype=np.float32)
grid = np.meshgrid(grid_h, grid_w, indexing='ij') # here w goes first
self.grid = torch.from_numpy(np.stack(grid, axis=0)[None]).to(device)
def next_frame_predict(self, action):
z = torch.randn(1, 4, 1, 32, 32, device=device)
past_pos = torch.cat([
torch.tensor([i for i in range(self.memory_frames)], device=device)[None, None, :, None, None].expand(1, 1, self.memory_frames, *self.all_poses.shape[-2:]),
self.grid[:,:,None,:,:].expand(1, 2, self.memory_frames, *self.all_poses.shape[-2:]),
self.all_actions[None, None, -self.memory_frames:, None, None].expand(1, 1, self.memory_frames, *self.all_poses.shape[-2:]),
], dim=1)
pos = torch.cat([
torch.tensor([self.memory_frames], device=device)[None, None, :, None, None].expand(1, 1, 1, *self.all_poses.shape[-2:]),
self.grid[:,:,None,:,:].expand(1, 2, 1, *self.all_poses.shape[-2:]),
torch.tensor([action], device=device)[None, None, :, None, None].expand(1, 1, 1, *self.all_poses.shape[-2:]),
], dim=1)
model_kwargs = dict(
pos=pos,
past_frame=self.all_frames[:, :, -self.memory_frames:],
past_pos=past_pos
)
with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.float16):
samples = self.diffusion.p_sample_loop(
self.model, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
)
self.all_frames = torch.cat([self.all_frames, samples], dim=2)
self.all_poses = torch.cat([self.all_poses, pos], dim=2)
self.all_actions = torch.cat([self.all_actions, torch.tensor([action], device=device)])
next_frame = self.vae.decode(samples[:, :, 0] / 0.13025).sample
next_frame = torch.clamp(next_frame, -1, 1).detach().cpu().numpy()
next_frame = rearrange(next_frame[0], "C H W -> H W C")
next_frame = PIL.Image.fromarray(np.uint8(255 * (next_frame / 2 + 0.5)))
return next_frame, "actions:" + str(self.all_actions.cpu().numpy())
def output_video(self):
_samples = rearrange(self.all_frames, "N C T H W -> (N T) C H W")
with torch.no_grad():
vidoes = []
for frame in _samples[::8]:
vidoes.append(self.vae.decode(frame.unsqueeze(0) / 0.13025).sample)
vidoes = torch.cat(vidoes)
vidoes = torch.clamp(vidoes, -1, 1)
del _samples
video = 255 * (vidoes.clip(-1, 1) / 2 + 0.5)
screeshoot = PIL.Image.fromarray(np.uint8(rearrange(video.cpu().numpy(), "T C H W -> H (T W) C")))
screeshoot.save("samples.png")
torchvision.io.write_video(
"samples.mp4",
video.permute(0, 2, 3, 1).cpu().numpy(),
fps=8,
video_codec="h264",
)
return "samples.mp4"
def create_interface(self):
with gr.Blocks() as demo:
gr.Markdown("# GenGame - OpenGenGAME/super-mario-bros-rl-1-1")
gr.Markdown("Try to control the game: Right, A, Right + A, Left")
with gr.Row():
image_output = gr.Image(label="Next Frame")
video_output = gr.Video(label="All Frames", autoplay=True, loop=True)
with gr.Row():
# left_btn = gr.Button("โ†")
right_btn = gr.Button("โ†’")
# a_btn = gr.Button("A")
right_a_btn = gr.Button("โ†’ A")
play_btn = gr.Button("conver video")
log_output = gr.Textbox(label="Actions Log", lines=5, interactive=False)
# sampling_steps = gr.Slider(minimum=10, maximum=50, value=10, step=1, label="Sampling Steps")
# left_btn.click(lambda: self.next_frame_predict(6), outputs=[image_output, log_output])
right_btn.click(lambda: self.next_frame_predict(3), outputs=[image_output, log_output])
# a_btn.click(lambda: self.next_frame_predict(5), outputs=[image_output, log_output])
right_a_btn.click(lambda: self.next_frame_predict(4), outputs=[image_output, log_output])
play_btn.click(self.output_video, outputs=video_output)
return demo
if __name__ == "__main__":
    # Both paths are relative to the working directory: the Mario T1 config
    # and the checkpoint file saved under results/002-t1/checkpoints/.
    config_file = "configs/mario_t1_v0.yaml"
    checkpoint_file = "results/002-t1/checkpoints/0100000.pt"
    ActionGameDemo(config_file, checkpoint_file).create_interface().launch()