Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
from wgpu.utils.shadertoy import * | |
from wgpu.gui.offscreen import WgpuCanvas as OffscreenCanvas, run as run_offscreen | |
import wgpu | |
import time | |
import ctypes | |
import datasets | |
from PIL import Image | |
import asyncio | |
import numpy as np | |
# reimplement the Shadertoy class with offscreen canvas! | |
class ShadertoyCustom(Shadertoy):
    """``Shadertoy`` variant with an injectable canvas class and run function.

    The parent class (from ``wgpu.utils.shadertoy``) renders into a fixed
    canvas type; this subclass lets callers pass e.g. the offscreen canvas and
    its ``run`` function so a frame can be rendered without a window.

    Args:
        shader_code: Shadertoy-style shader source. Whether it is treated as
            GLSL or WGSL is decided by the parent's ``shader_type``.
        resolution: ``(width, height)``. Must be a tuple: it is concatenated
            with ``(1,)`` to fill the 3-component ``resolution`` uniform.
        canvas_class: canvas type instantiated by ``_prepare_render``.
        run_fn: event-loop runner called by ``show``.
    """

    def __init__(self, shader_code, resolution=(800, 450), canvas_class=WgpuCanvas, run_fn=run):
        # Both attributes must exist before super().__init__().
        # NOTE(review): the parent constructor presumably already calls the
        # overridden _prepare_render()/_bind_events(), in which case the setup
        # below runs a second time (second canvas/device) - confirm against the
        # installed wgpu version before removing either call.
        self._canvas_class = canvas_class
        self._run_fn = run_fn
        super().__init__(shader_code, resolution)
        self._uniform_data = UniformArray(
            ("mouse", "f", 4),
            ("resolution", "f", 3),
            ("time", "f", 1),
            ("time_delta", "f", 1),
            ("frame", "I", 1),
        )
        self._shader_code = shader_code
        self._uniform_data["resolution"] = resolution + (1,)
        self._prepare_render()
        self._bind_events()

    def _prepare_render(self):
        """Create the canvas, device, uniform buffer and render pipeline.

        Raises:
            ValueError: if ``shader_type`` is neither ``"glsl"`` nor ``"wgsl"``.
        """
        import wgpu.backends.rs  # noqa - imported only for its side effect (presumably backend selection)

        self._canvas = self._canvas_class(title="Shadertoy", size=self.resolution, max_fps=60)
        adapter = wgpu.request_adapter(
            canvas=self._canvas, power_preference="high-performance"
        )
        self._device = adapter.request_device()
        self._present_context = self._canvas.get_context()
        # We use "bgra8unorm" not "bgra8unorm-srgb" here because we want to let
        # the shader fully control the color-space. Pixels of this target are
        # laid out as B, G, R, A bytes.
        self._present_context.configure(
            device=self._device, format=wgpu.TextureFormat.bgra8unorm
        )

        shader_type = self.shader_type
        if shader_type == "glsl":
            vertex_shader_code = vertex_code_glsl
            frag_shader_code = (
                builtin_variables_glsl + self.shader_code + fragment_code_glsl
            )
        elif shader_type == "wgsl":
            vertex_shader_code = vertex_code_wgsl
            frag_shader_code = (
                builtin_variables_wgsl + self.shader_code + fragment_code_wgsl
            )
        else:
            # Without this branch an unknown type surfaced as UnboundLocalError.
            raise ValueError(f"unsupported shader type: {shader_type!r}")

        vertex_shader_program = self._device.create_shader_module(
            label="triangle_vert", code=vertex_shader_code
        )
        frag_shader_program = self._device.create_shader_module(
            label="triangle_frag", code=frag_shader_code
        )

        self._uniform_buffer = self._device.create_buffer(
            size=self._uniform_data.nbytes,
            usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST,
        )
        bind_group_layout = self._device.create_bind_group_layout(
            entries=binding_layout
        )
        self._bind_group = self._device.create_bind_group(
            layout=bind_group_layout,
            entries=[
                {
                    "binding": 0,
                    "resource": {
                        "buffer": self._uniform_buffer,
                        "offset": 0,
                        "size": self._uniform_data.nbytes,
                    },
                },
            ],
        )

        self._render_pipeline = self._device.create_render_pipeline(
            layout=self._device.create_pipeline_layout(
                bind_group_layouts=[bind_group_layout]
            ),
            vertex={
                "module": vertex_shader_program,
                "entry_point": "main",
                "buffers": [],
            },
            primitive={
                "topology": wgpu.PrimitiveTopology.triangle_list,
                "front_face": wgpu.FrontFace.ccw,
                "cull_mode": wgpu.CullMode.none,
            },
            depth_stencil=None,
            multisample=None,
            fragment={
                "module": frag_shader_program,
                "entry_point": "main",
                "targets": [
                    {
                        "format": wgpu.TextureFormat.bgra8unorm,
                        "blend": {
                            "color": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                            "alpha": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                        },
                    },
                ],
            },
        )

    def show(self, time: float = 0.0):
        """Schedule a frame draw, then start the event loop via ``run_fn``.

        ``time`` is accepted for interface compatibility but is not used.
        """
        self._canvas.request_draw(self._draw_frame)
        self._run_fn()
def make_script(shader_code):
    """Build a standalone WebGL page that renders a Shadertoy-style shader.

    Adapted from https://webglfundamentals.org/webgl/webgl-shadertoy.html.
    The page declares only the ``iResolution`` (vec2), ``iMouse`` (vec2) and
    ``iTime`` uniforms and calls ``mainImage(gl_FragColor, gl_FragCoord.xy)``.
    Shaders that use other Shadertoy inputs will therefore fail to compile and
    leave the canvas blank.

    The markup deliberately contains no single quotes, so the page can be
    embedded in a single-quoted ``srcdoc`` attribute.

    Args:
        shader_code: GLSL source that defines ``mainImage``.

    Returns:
        str: the complete HTML document.
    """
    # The shader source is spliced into a JavaScript template literal (the
    # ``fs`` constant below). In a template literal, backslashes, backticks and
    # "${" are special. Escape them so the GLSL reaches WebGL byte-for-byte;
    # for example, line continuations in #define macros survive.
    js_safe_code = (
        shader_code.replace("\\", "\\\\")
        .replace("`", "\\`")
        .replace("${", "\\${")
    )
    script = ("""
<!-- Licensed under a BSD license. See license.html for license -->
<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
    <title>WebGL - Shadertoy</title>
    <link type="text/css" href="https://webglfundamentals.org/webgl/resources/webgl-tutorials.css" rel="stylesheet" />
    <style>
      .divcanvas {
        position: relative;
        display: inline-block;
      }
      canvas {
        display: block;
      }
      .playpause {
        position: absolute;
        left: 10px;
        top: 10px;
        width: 100%;
        height: 100%;
        font-size: 60px;
        justify-content: center;
        align-items: center;
        color: rgba(255, 255, 255, 0.3);
        transition: opacity 0.2s ease-in-out;
      }
      .playpausehide,
      .playpause:hover {
        opacity: 0;
      }
      .iframe .divcanvas {
        display: block;
      }
    </style>
  </head>
  <body>
    <div class="divcanvas">
      <canvas id="canvas"></canvas>
      <div class="playpause">▶</div>
    </div>
    blank canvas here indicates that some of the shadertoy specific functions are not yet supported with this implementation (like #define I believe). you can always copy and paste the code into a shadertoy.com window to try.
  </body>
<!--
for most samples webgl-utils only provides shader compiling/linking and
canvas resizing because why clutter the examples with code thats the same in every sample.
See https://webglfundamentals.org/webgl/lessons/webgl-boilerplate.html
and https://webglfundamentals.org/webgl/lessons/webgl-resizing-the-canvas.html
for webgl-utils, m3, m4, and webgl-lessons-ui.
-->
<script src="https://webglfundamentals.org/webgl/resources/webgl-utils.js"></script>
<script>
"use strict";

function main() {
  // Get A WebGL context
  /** @type {HTMLCanvasElement} */
  const canvas = document.querySelector("#canvas");
  const gl = canvas.getContext("webgl");
  if (!gl) {
    return;
  }

  const vs = `
    // an attribute will receive data from a buffer
    attribute vec4 a_position;

    // all shaders have a main function
    void main() {

      // gl_Position is a special variable a vertex shader
      // is responsible for setting
      gl_Position = a_position;
    }
  `;

  const fs = `
precision highp float;

uniform vec2 iResolution;
uniform vec2 iMouse;
uniform float iTime;

""" + js_safe_code + """

void main() {
  mainImage(gl_FragColor, gl_FragCoord.xy);
}
`;

  // setup GLSL program
  const program = webglUtils.createProgramFromSources(gl, [vs, fs]);

  // look up where the vertex data needs to go.
  const positionAttributeLocation = gl.getAttribLocation(program, "a_position");

  // look up uniform locations
  const resolutionLocation = gl.getUniformLocation(program, "iResolution");
  const mouseLocation = gl.getUniformLocation(program, "iMouse");
  const timeLocation = gl.getUniformLocation(program, "iTime");

  // Create a buffer to put three 2d clip space points in
  const positionBuffer = gl.createBuffer();

  // Bind it to ARRAY_BUFFER (think of it as ARRAY_BUFFER = positionBuffer)
  gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);

  // fill it with a 2 triangles that cover clipspace
  gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([
    -1, -1,  // first triangle
     1, -1,
    -1,  1,
    -1,  1,  // second triangle
     1, -1,
     1,  1,
  ]), gl.STATIC_DRAW);

  const playpauseElem = document.querySelector(".playpause");
  const inputElem = document.querySelector(".divcanvas");
  inputElem.addEventListener("mouseover", requestFrame);
  inputElem.addEventListener("mouseout", cancelFrame);

  let mouseX = 0;
  let mouseY = 0;

  function setMousePosition(e) {
    const rect = inputElem.getBoundingClientRect();
    mouseX = e.clientX - rect.left;
    mouseY = rect.height - (e.clientY - rect.top) - 1;  // bottom is 0 in WebGL
  }

  inputElem.addEventListener("mousemove", setMousePosition);
  inputElem.addEventListener("touchstart", (e) => {
    e.preventDefault();
    playpauseElem.classList.add("playpausehide");
    requestFrame();
  }, {passive: false});
  inputElem.addEventListener("touchmove", (e) => {
    e.preventDefault();
    setMousePosition(e.touches[0]);
  }, {passive: false});
  inputElem.addEventListener("touchend", (e) => {
    e.preventDefault();
    playpauseElem.classList.remove("playpausehide");
    cancelFrame();
  }, {passive: false});

  let requestId;
  function requestFrame() {
    if (!requestId) {
      requestId = requestAnimationFrame(render);
    }
  }
  function cancelFrame() {
    if (requestId) {
      cancelAnimationFrame(requestId);
      requestId = undefined;
    }
  }

  let then = 0;
  let time = 0;
  function render(now) {
    requestId = undefined;
    now *= 0.001;  // convert to seconds
    const elapsedTime = Math.min(now - then, 0.1);
    time += elapsedTime;
    then = now;

    webglUtils.resizeCanvasToDisplaySize(gl.canvas);

    // Tell WebGL how to convert from clip space to pixels
    gl.viewport(0, 0, gl.canvas.width, gl.canvas.height);

    // Tell it to use our program (pair of shaders)
    gl.useProgram(program);

    // Turn on the attribute
    gl.enableVertexAttribArray(positionAttributeLocation);

    // Bind the position buffer.
    gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);

    // Tell the attribute how to get data out of positionBuffer (ARRAY_BUFFER)
    gl.vertexAttribPointer(
        positionAttributeLocation,
        2,          // 2 components per iteration
        gl.FLOAT,   // the data is 32bit floats
        false,      // dont normalize the data
        0,          // 0 = move forward size * sizeof(type) each iteration to get the next position
        0,          // start at the beginning of the buffer
    );

    gl.uniform2f(resolutionLocation, gl.canvas.width, gl.canvas.height);
    gl.uniform2f(mouseLocation, mouseX, mouseY);
    gl.uniform1f(timeLocation, time);

    gl.drawArrays(
        gl.TRIANGLES,
        0,     // offset
        6,     // num vertices to process
    );

    requestFrame();
  }

  requestFrame();
  requestAnimationFrame(cancelFrame);
}

main();
</script>
</html>
""")
    return script
def make_iframe(shader_code):  # keep a single function?
    """Wrap the WebGL page built by ``make_script`` in an ``<iframe>``.

    The page is passed through the ``srcdoc`` attribute. It is HTML-escaped
    first, including both quote characters. The browser decodes the entities
    back, so shader code containing ``'``, ``"``, ``&`` or ``<`` can no longer
    end the attribute early or be mangled.

    Args:
        shader_code: GLSL source that defines ``mainImage``.

    Returns:
        str: an ``<iframe>`` element (640x512) embedding the page.
    """
    import html

    script = make_script(shader_code)
    escaped_script = html.escape(script, quote=True)
    return f"""<iframe width="640" height="512" srcdoc='{escaped_script}' allowfullscreen></iframe>"""
# Markdown shown at the top of the app (rendered by gr.Markdown).
text = """
# Welcome to the interactive shadercoding demo.
## (WIP), you can try and explore the dataset a bit right now. (frames are rendered on the fly, not part of the dataset(yet))
This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset, only shaders that consist of a single pass (and have at least one function with a return statement) are available.
In the near future there will be some buttons and sliders to generate variations of the shadercode itself, and hence get some different images.
If I find an efficient way, the shaders might run in real time and be interactive.
## TODO:
- [x] use embedded Shadertoy for reference/attribution (done, but some errors)
- [ ] working render implementation on CPU only space (use the browser for WebGPU?, maybe via an iFrame too?) freespace uses lavapipe which works on really simple stuff.
- [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (needs to be reworked using the other parts)
- [ ] generation history stating which function and orig/generated returns. (use State ??). do it as comments in the code?
- [x] generate whole functions (seems to work quite well)
- [ ] display errors/issues to the user
- [ ] generate whole shaders (via prompts?)
- [ ] accordion with generation parameters (as pipeline_kwargs?)
"""
# Load the Shadertoys dataset and keep only rows whose "has_inputs" flag is
# false and that have exactly one pass. The train and test splits are then
# merged into a single pool, which the sample slider indexes into.
passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
single_passes = passes_dataset.filter(
    lambda row: (not row["has_inputs"]) and row["num_passes"] == 1
)  # could also include shaders with no extra functions.
all_single_passes = datasets.concatenate_datasets(
    [single_passes[split_name] for split_name in ("train", "test")]
)
num_samples = len(all_single_passes)
import tree_sitter | |
from tree_sitter import Language, Parser | |
# Module-wide GLSL parser used by get_image, _parse_functions and friends.
parser = Parser()
# Compile the grammar found at "tree-sitter-glsl" (relative to the working
# directory) into a shared library, then load the "glsl" language from it.
# NOTE(review): Language.build_library only exists in older py-tree-sitter
# releases - confirm the pinned version.
Language.build_library("./build/my-languages.so", ["tree-sitter-glsl"])
GLSL_LANGUAGE = Language("./build/my-languages.so", "glsl")
parser.set_language(GLSL_LANGUAGE)
async def get_image(code, time=0.0, resolution=(512, 420)):
    """Render one frame of ``code`` offscreen and return it as an RGB image.

    Args:
        code: Shadertoy-style GLSL source.
        time: value written to the shader's ``time`` uniform.
        resolution: ``(width, height)`` tuple for the offscreen canvas.

    Returns:
        PIL.Image.Image: the rendered frame in RGB mode (alpha dropped).

    Raises:
        gr.Error: if tree-sitter reports a syntax error in ``code``.
    """
    tree = parser.parse(bytes(code, "utf8"))
    if tree.root_node.has_error:
        print("ERROR in the tree, aborting.")
        raise gr.Error("the code seems to have issues")

    shader = ShadertoyCustom(code, resolution, OffscreenCanvas, run_offscreen)  # pass offscreen canvas here.
    shader._uniform_data["time"] = time  # set any time you want
    shader._canvas.request_draw(shader._draw_frame)
    # NOTE(review): assumes draw() returns a (height, width, 4) uint8 buffer
    # in the render target's byte order - confirm with the installed wgpu.
    frame = np.asarray(shader._canvas.draw())
    # The target is configured as "bgra8unorm", so the channels are B, G, R, A.
    # Handing that to PIL as RGBA swapped red and blue. Reorder to R, G, B,
    # which also drops the alpha channel.
    rgb_frame = frame[..., [2, 1, 0]]
    return Image.fromarray(rgb_frame)
def grab_sample(sample_idx):
    """Load one dataset sample for the UI.

    Args:
        sample_idx: index into ``all_single_passes``. It may arrive from the
            slider as a float; it is truncated to int.

    Returns:
        tuple: ``(sample_pass, sample_code, source_iframe)``. These are the raw
        dataset row, its shader code, and a Shadertoy embed of its source URL,
        matching the three outputs wired to ``sample_idx.release``. The
        function dropdown is refreshed separately by ``sample_code.change``.
    """
    sample_pass = all_single_passes[int(sample_idx)]
    sample_code = sample_pass["code"]
    source_iframe = construct_embed(sample_pass["source"])
    print(f"{source_iframe=}")
    # Previously this also returned the undefined name `funcs` (NameError)
    # and yielded four values for three wired outputs.
    return sample_pass, sample_code, source_iframe
def _parse_functions(in_code):
    """Parse ``in_code`` and return its top-level function nodes.

    Only direct children of the syntax tree's root whose type is
    ``"function_definition"`` are returned, in source order. Parsing uses the
    module-level GLSL ``parser``.
    """
    root = parser.parse(in_code.encode("utf8")).root_node
    return [child for child in root.children if child.type == "function_definition"]
PIPE = None | |
def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing | |
tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True) | |
model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True) | |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True) | |
PIPE = pipe # set the global? | |
print(f"loaded model {model_cp} as a pipline") | |
return pipe | |
def process_retn(retn):
    """Cut generated text at its first ``;`` and strip surrounding whitespace."""
    expression, _separator, _remainder = retn.partition(";")
    return expression.strip()
def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
    """Splice a generated return expression into the code.

    ``orig_code[retn_start_idx:retn_end_idx]`` is replaced by ``prediction``
    after it has been cut at its first ``;`` and stripped (``process_retn``).

    Returns:
        str: the full altered code.
    """
    print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
    generated = process_retn(prediction)
    print(f"{generated=}")
    head = orig_code[:retn_start_idx]
    tail = orig_code[retn_end_idx:]
    return "".join((head, generated, tail))
def alter_return(orig_code, func_idx=0, pipeline=PIPE):  # default pipeline can't be passed as global?
    """Replace one return expression in the code with a model-generated one.

    Return statements are found textually: the whole word ``return`` up to
    the next ``;``. No syntax tree is used, so a ``return`` inside a comment
    also counts.

    Args:
        orig_code (str): The original shader code.
        func_idx (int | float): Index of the *return statement* (not the
            function) to replace. It is clamped to the valid range and
            truncated to int, because the UI passes a slider value.
        pipeline: text-generation pipeline. When None, a default one is loaded
            for this call only.

    Returns:
        str: the altered code, or ``orig_code`` unchanged if it contains no
        complete ``return ...;`` statement.
    """
    import re

    # Collect (start, end) spans of "return ... ;". The word boundaries keep
    # identifiers such as "returnColor" from matching. A "return" with no ";"
    # after it ends the scan, instead of recording a bogus (start, -1) span.
    return_keyword = re.compile(r"\breturn\b")
    retrns = []
    match = return_keyword.search(orig_code)
    while match is not None:
        retrn_end_idx = orig_code.find(";", match.end())
        if retrn_end_idx == -1:
            break
        retrns.append((match.start(), retrn_end_idx))
        match = return_keyword.search(orig_code, retrn_end_idx)

    num_returns = len(retrns)
    if num_returns == 0:
        print("no return statement found, returning original code")
        return orig_code

    # Only load a model once we know there is something to replace.
    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline()

    func_idx = int(max(0, min(func_idx, num_returns - 1)))  # clamp to valid range, cast to int as a bodge.
    retrn_start_idx, retrn_end_idx = retrns[func_idx]
    model_context = orig_code[:retrn_start_idx]  # TODO: maximal context?
    model_inp = model_context + "return"
    new_toks = (retrn_end_idx - retrn_start_idx) * 2  # TODO: approximation, we do have early stopping?
    pipe_generation = pipeline(model_inp, max_new_tokens=new_toks, return_full_text=False)[0]["generated_text"]  # pipeline kwargs are missing?!

    # Replace only the returned expression. Keep the keyword, plus one
    # whitespace character if one follows it. The old fixed "+7" assumed
    # exactly one space: for "return;" it kept the ";" (duplicating it), and
    # for "return(x)" it kept the "(" although the model continues directly
    # after "return".
    value_start_idx = retrn_start_idx + len("return")
    if orig_code[value_start_idx].isspace():
        value_start_idx += 1
    altered_code = get_full_replacement(orig_code, value_start_idx, retrn_end_idx, pipe_generation)
    return altered_code
def _line_chr2char(text, line_idx, chr_idx): | |
""" | |
returns the character index at the given line and character index. | |
""" | |
lines = text.split("\n") | |
char_idx = 0 | |
for i in range(line_idx): | |
char_idx += len(lines[i]) + 1 | |
char_idx += chr_idx | |
return char_idx | |
def alter_body(old_code, func_id, funcs_list, pipeline=PIPE):
    """Replace the body of one function in the code with a generated one.

    Args:
        old_code (str): The current shader code.
        func_id (str): Dropdown label of the form ``"<idx>: <declarator>"``,
            as built by ``list_dropdown``. Only the index before ``:`` is used.
        funcs_list (list): tree-sitter ``function_definition`` nodes parsed
            from ``old_code``. The index from ``func_id`` selects one of them.
        pipeline: text-generation pipeline. A default one is loaded when None.

    Returns:
        tuple: ``(altered_code, pipeline)``. The pipeline is returned so the
        caller can keep a freshly loaded one in state.

    Raises:
        gr.Error: if no valid function is selected, or if the generation does
        not parse into a function with a body.
    """
    print(f"{func_id=}")
    if not func_id:
        raise gr.Error("pick a function from the dropdown first")
    func_id = int(func_id.split(":")[0].strip())  # undo the dropdown's "<idx>: <name>" label
    if not 0 <= func_id < len(funcs_list):
        raise gr.Error("the function list is out of date, pick the function again")
    func_node = funcs_list[func_id]
    print(f"using for generation: {func_node=}")
    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine")

    # tree-sitter parsed the UTF-8 bytes of the code. Its (row, column) points
    # count bytes, so the old character-based conversion drifted on any
    # non-ASCII text (e.g. comments). Slice the encoded bytes with byte offsets.
    code_bytes = old_code.encode("utf8")
    identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode()
    body_node = func_node.child_by_field_name("body")
    body_start_idx, body_end_idx = body_node.start_byte, body_node.end_byte
    old_body = code_bytes[body_start_idx:body_end_idx].decode("utf8")
    print(f"{old_body=}")

    model_context = identifier_str  # only the signature is used as the prompt
    num_new_tokens = (body_end_idx - body_start_idx) + 10  # TODO: approximation, we do have early stopping?
    print(f"generating up to {num_new_tokens} after {model_context!r}")
    generation = pipeline(model_context, max_new_tokens=num_new_tokens, return_full_text=False)[0]["generated_text"]
    print(f"{generation=}")
    id_with_generation = identifier_str + generation
    print(f"{id_with_generation=}")

    # Keep only the first generated function and take just its body.
    gened_funcs = _parse_functions(id_with_generation)
    if not gened_funcs:
        raise gr.Error("the generation could not be parsed as a function, try again")
    first_gened_func = gened_funcs[0]
    print(f"{first_gened_func=}")
    gened_body_node = first_gened_func.child_by_field_name("body")
    if gened_body_node is None:
        raise gr.Error("the generated function has no body, try again")
    generated_body = gened_body_node.text.decode()
    print(f"{generated_body=}")

    altered_bytes = code_bytes[:body_start_idx] + generated_body.encode("utf8") + code_bytes[body_end_idx:]
    return altered_bytes.decode("utf8"), pipeline
def add_history(func_id, orig_rtn, gened_rtn, history):
    """Record an (original, generated) return pair under ``func_id``.

    ``history`` is mutated in place and returned twice, so one call can feed
    two outputs (e.g. a display and a state).
    """
    entry = (orig_rtn, gened_rtn)
    history[func_id] = entry
    return (history,) * 2
def list_dropdown(in_code):  # only used for auto update, not on sample pick?
    """Parse the code and rebuild the function dropdown.

    Returns:
        tuple: the function nodes (for the ``funcs`` state) and a
        ``gr.Dropdown.update`` whose choices are ``"<idx>: <declarator>"``
        labels, with the index right-aligned to width 2.
    """
    funcs = _parse_functions(in_code)
    func_identifiers = []
    for idx, node in enumerate(funcs):
        declarator = node.child_by_field_name("declarator").text.decode()
        func_identifiers.append(f"{idx:2d}: {declarator}")
    print(f"updating drop down to:{func_identifiers}")
    return funcs, gr.Dropdown.update(choices=func_identifiers)
def construct_embed(source_url):
    """Build a Shadertoy embed iframe for a shader page URL.

    The shader id is the last path segment of ``source_url``. Trailing slashes
    are ignored; before, they produced an empty id.

    Args:
        source_url: e.g. ``"https://www.shadertoy.com/view/WsBcWV"``.

    Returns:
        str: a 640x360 paused, muted ``<iframe>`` for the shader.
    """
    shader_id = source_url.rstrip("/").split("/")[-1]
    return f'<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'
# Gradio UI: pick a dataset sample, generate code variations, render frames.
with gr.Blocks() as site:
    text_md = gr.Markdown(text)
    model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
    # Valid dataset indices are 0 .. num_samples - 1. A maximum of num_samples
    # let the slider select one past the end (IndexError in grab_sample).
    sample_idx = gr.Slider(minimum=0, maximum=num_samples - 1, value=3211, label="pick sample from dataset", step=1.0)
    func_dropdown = gr.Dropdown(label="choose a function to modify")  # breaks if I add a string in before that?
    with gr.Row():
        gen_return_button = gr.Button("generate an alternate return statement", label="generate return")
        gen_func_button = gr.Button("generate an alternate function body", label="generate function")
        # update_funcs_button = gr.Button("update functions", label="update functions")
    render_button = gr.Button("render frame0 (can crash the space on invalid shadercode)", label="render frame")
    time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release, also used to pick other functions as a bodge)", step=0.02)
    with gr.Row():
        with gr.Column():
            source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/WsBcWV?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>', label="How this shader originally renders")
            rendered_frame = gr.Image(shape=(512, 420), label="rendered frame preview", type="pil")
            our_embed = gr.HTML(label="glsl render of the current code")
        sample_code = gr.Code(label="Current Code (will update changes you generate)", language=None)
    sample_pass = gr.State(value={})
    pipe = gr.State(value=PIPE)
    funcs = gr.State(value=[])
    # hist_state = gr.State(Value={})
    # history_table = gr.JSON()

    model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe])
    sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed])
    # The time slider doubles as the index of the return statement to replace.
    gen_return_button.click(fn=alter_return, inputs=[sample_code, time_slider, pipe], outputs=[sample_code])
    gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, pipe], outputs=[sample_code, pipe])
    # Re-parse after every code change so the function spans are never stale.
    sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown])
    sample_code.change(fn=make_iframe, inputs=[sample_code], outputs=[our_embed])  # twice could cause issues, find better ways.
    time_slider.release(fn=lambda code, time: asyncio.run(get_image(code, time)), inputs=[sample_code, time_slider], outputs=rendered_frame)
    render_button.click(fn=lambda code: asyncio.run(get_image(code)), inputs=[sample_code], outputs=rendered_frame)

site.launch()