"""
HF_HOME=/mnt/store/kmei1/HF_HOME/ python gradio-app.py
"""
from omegaconf import OmegaConf
import torch
import torchvision
from diffusers.models import AutoencoderKL
from einops import rearrange
from models.t1 import Model as T1Model
from datasets.action_video import Dataset as ActionDataset
from diffusion import create_diffusion
import gradio as gr
import numpy as np
import PIL.Image  # import the submodule explicitly; bare "import PIL" does not expose PIL.Image
from tqdm import tqdm
device = "cuda:6"
class ActionGameDemo:
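    """Interactive demo that autoregressively predicts the next game frame,
    conditioned on the last `memory_frames` latent frames and a player action,
    using the T1 diffusion model with the SDXL VAE for encoding/decoding."""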
    nframes: int = 17
    memory_frames: int = 16
    def __init__(self, config_path, ckpt_path) -> None:
        configs = OmegaConf.load(config_path)
        configs.dataset.data_root = "demo"

        # Load the T1 world model from the EMA weights in the checkpoint.
        model = T1Model(**configs.get("model", {}))
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["ema"])
        model = model.to(device)
        self.model = model
        self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)

        # Use the first sample of the demo dataset as the seed context.
        dataset = ActionDataset(**configs.get("dataset", {}))
        data = dataset[0]
        x, action, pos = data
        pos = torch.from_numpy(pos[None]).to(device)
        x = torch.from_numpy(x[None]).to(device)
        action = torch.tensor(action).to(device)
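        # Encode the seed frames into VAE latent space; 0.13025 is the
        # SDXL-VAE scaling factor.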
        with torch.no_grad():
            T = x.shape[2]
            x = rearrange(x, "N C T H W -> (N T) C H W")
            x = self.vae.encode(x).latent_dist.sample().mul_(0.13025)
            x = rearrange(x, "(N T) C H W -> N C T H W", T=T)

        self.all_frames = x
        self.all_poses = pos
        self.all_actions = action
        self.diffusion = create_diffusion(
            timestep_respacing="ddim20",
        )

        # Spatial grid at half the latent resolution; with indexing='ij',
        # grid[0] indexes H and grid[1] indexes W.
        _video = x[0]
        grid_h = np.arange(_video.shape[2] // 2, dtype=np.float32)
        grid_w = np.arange(_video.shape[3] // 2, dtype=np.float32)
        grid = np.meshgrid(grid_h, grid_w, indexing='ij')
        self.grid = torch.from_numpy(np.stack(grid, axis=0)[None]).to(device)
    def next_frame_predict(self, action):
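        """Sample one new latent frame conditioned on the last
        `memory_frames` frames and the given action id."""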
        z = torch.randn(1, 4, 1, 32, 32, device=device)
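        # Positional conditioning carries 4 channels per frame: the time
        # index, the 2D spatial grid (h, w), and the action id, each
        # broadcast over the pose map's H x W.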
        past_pos = torch.cat([
            torch.arange(self.memory_frames, device=device)[None, None, :, None, None].expand(1, 1, self.memory_frames, *self.all_poses.shape[-2:]),
            self.grid[:, :, None, :, :].expand(1, 2, self.memory_frames, *self.all_poses.shape[-2:]),
            self.all_actions[None, None, -self.memory_frames:, None, None].expand(1, 1, self.memory_frames, *self.all_poses.shape[-2:]),
        ], dim=1)
        pos = torch.cat([
            torch.tensor([self.memory_frames], device=device)[None, None, :, None, None].expand(1, 1, 1, *self.all_poses.shape[-2:]),
            self.grid[:, :, None, :, :].expand(1, 2, 1, *self.all_poses.shape[-2:]),
            torch.tensor([action], device=device)[None, None, :, None, None].expand(1, 1, 1, *self.all_poses.shape[-2:]),
        ], dim=1)
        model_kwargs = dict(
            pos=pos,
            past_frame=self.all_frames[:, :, -self.memory_frames:],
            past_pos=past_pos,
        )
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                samples = self.diffusion.p_sample_loop(
                    self.model, z.shape, z, clip_denoised=False,
                    model_kwargs=model_kwargs, progress=False, device=device,
                )

        # Append the new frame to the rolling history.
        self.all_frames = torch.cat([self.all_frames, samples], dim=2)
        self.all_poses = torch.cat([self.all_poses, pos], dim=2)
        self.all_actions = torch.cat([self.all_actions, torch.tensor([action], device=device)])

        # Decode the new latent frame back to an RGB image for display.
        with torch.no_grad():
            next_frame = self.vae.decode(samples[:, :, 0] / 0.13025).sample
        next_frame = torch.clamp(next_frame, -1, 1).cpu().numpy()
        next_frame = rearrange(next_frame[0], "C H W -> H W C")
        next_frame = PIL.Image.fromarray(np.uint8(255 * (next_frame / 2 + 0.5)))
        return next_frame, "actions: " + str(self.all_actions.cpu().numpy())
    def output_video(self):
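        """Decode the accumulated latent frames and write them out as an
        MP4 plus a film-strip PNG."""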
        _samples = rearrange(self.all_frames, "N C T H W -> (N T) C H W")
        with torch.no_grad():
            videos = []
            # Decode every 8th latent frame (temporal subsampling).
            for frame in _samples[::8]:
                videos.append(self.vae.decode(frame.unsqueeze(0) / 0.13025).sample)
            videos = torch.cat(videos)
            videos = torch.clamp(videos, -1, 1)
        del _samples

        video = 255 * (videos / 2 + 0.5)
        screenshot = PIL.Image.fromarray(np.uint8(rearrange(video.cpu().numpy(), "T C H W -> H (T W) C")))
        screenshot.save("samples.png")
        torchvision.io.write_video(
            "samples.mp4",
            video.permute(0, 2, 3, 1).cpu().numpy(),
            fps=8,
            video_codec="h264",
        )
        return "samples.mp4"
    def create_interface(self):
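        """Build the Gradio Blocks UI: the action buttons drive
        next_frame_predict, and the convert button renders the video."""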
        with gr.Blocks() as demo:
            gr.Markdown("# GenGame - OpenGenGAME/super-mario-bros-rl-1-1")
            gr.Markdown("Try to control the game: Right, A, Right + A, Left")
            with gr.Row():
                image_output = gr.Image(label="Next Frame")
                video_output = gr.Video(label="All Frames", autoplay=True, loop=True)
            with gr.Row():
                # left_btn = gr.Button("←")
                right_btn = gr.Button("→")
                # a_btn = gr.Button("A")
                right_a_btn = gr.Button("→ A")
                play_btn = gr.Button("Convert to video")
            log_output = gr.Textbox(label="Actions Log", lines=5, interactive=False)
            # sampling_steps = gr.Slider(minimum=10, maximum=50, value=10, step=1, label="Sampling Steps")

            # Button callbacks map to the dataset's action ids
            # (3 = Right, 4 = Right + A; 5 = A and 6 = Left are disabled).
            # left_btn.click(lambda: self.next_frame_predict(6), outputs=[image_output, log_output])
            right_btn.click(lambda: self.next_frame_predict(3), outputs=[image_output, log_output])
            # a_btn.click(lambda: self.next_frame_predict(5), outputs=[image_output, log_output])
            right_a_btn.click(lambda: self.next_frame_predict(4), outputs=[image_output, log_output])
            play_btn.click(self.output_video, outputs=video_output)
        return demo
if __name__ == "__main__":
    action_game = ActionGameDemo(
        "configs/mario_t1_v0.yaml",
        "results/002-t1/checkpoints/0100000.pt",
    )
    demo = action_game.create_interface()
    demo.launch()